diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptor.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptor.java new file mode 100644 index 0000000000..ec9083fefa --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptor.java @@ -0,0 +1,152 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * Converts a BQ table schema to protobuf descriptor. All field names will be converted to lowercase + * when constructing the protobuf descriptor. The mapping between field types and field modes are + * shown in the ImmutableMaps below. + */ +public class BQTableSchemaToProtoDescriptor { + private static ImmutableMap + BQTableSchemaModeMap = + ImmutableMap.of( + TableFieldSchema.Mode.NULLABLE, FieldDescriptorProto.Label.LABEL_OPTIONAL, + TableFieldSchema.Mode.REPEATED, FieldDescriptorProto.Label.LABEL_REPEATED, + TableFieldSchema.Mode.REQUIRED, FieldDescriptorProto.Label.LABEL_REQUIRED); + + private static ImmutableMap + BQTableSchemaTypeMap = + new ImmutableMap.Builder() + .put(TableFieldSchema.Type.BOOL, FieldDescriptorProto.Type.TYPE_BOOL) + .put(TableFieldSchema.Type.BYTES, FieldDescriptorProto.Type.TYPE_BYTES) + .put(TableFieldSchema.Type.DATE, FieldDescriptorProto.Type.TYPE_INT32) + .put(TableFieldSchema.Type.DATETIME, FieldDescriptorProto.Type.TYPE_INT64) + .put(TableFieldSchema.Type.DOUBLE, FieldDescriptorProto.Type.TYPE_DOUBLE) + .put(TableFieldSchema.Type.GEOGRAPHY, FieldDescriptorProto.Type.TYPE_STRING) + .put(TableFieldSchema.Type.INT64, FieldDescriptorProto.Type.TYPE_INT64) + .put(TableFieldSchema.Type.NUMERIC, FieldDescriptorProto.Type.TYPE_BYTES) + .put(TableFieldSchema.Type.STRING, FieldDescriptorProto.Type.TYPE_STRING) + .put(TableFieldSchema.Type.STRUCT, FieldDescriptorProto.Type.TYPE_MESSAGE) + .put(TableFieldSchema.Type.TIME, FieldDescriptorProto.Type.TYPE_INT64) + .put(TableFieldSchema.Type.TIMESTAMP, FieldDescriptorProto.Type.TYPE_INT64) + .build(); + + /** + * Converts TableFieldSchema to a Descriptors.Descriptor object. 
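// Illustrative sketch (not part of this patch): converting a BigQuery TableSchema into a protobuf
// Descriptor with the method below. The TableSchema/TableFieldSchema builder calls are the standard
// generated-proto API and are assumed here; field names are lowercased in the resulting descriptor,
// and STRUCT fields become nested message types.
TableSchema tableSchema =
    TableSchema.newBuilder()
        .addFields(
            TableFieldSchema.newBuilder()
                .setName("Name")
                .setType(TableFieldSchema.Type.STRING)
                .setMode(TableFieldSchema.Mode.REQUIRED)
                .build())
        .addFields(
            TableFieldSchema.newBuilder()
                .setName("Age")
                .setType(TableFieldSchema.Type.INT64)
                .setMode(TableFieldSchema.Mode.NULLABLE)
                .build())
        .build();
Descriptor descriptor =
    BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema);
// descriptor.findFieldByName("name") and descriptor.findFieldByName("age") are now resolvable.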
+ * + * @param BQTableSchema + * @throws Descriptors.DescriptorValidationException + */ + public static Descriptor convertBQTableSchemaToProtoDescriptor(TableSchema BQTableSchema) + throws Descriptors.DescriptorValidationException { + Preconditions.checkNotNull(BQTableSchema, "BQTableSchema is null."); + return convertBQTableSchemaToProtoDescriptorImpl( + BQTableSchema, "root", new HashMap, Descriptor>()); + } + + /** + * Converts a TableFieldSchema to a Descriptors.Descriptor object. + * + * @param BQTableSchema + * @param scope Keeps track of current scope to prevent repeated naming while constructing + * descriptor. + * @param dependencyMap Stores already constructed descriptors to prevent reconstruction + * @throws Descriptors.DescriptorValidationException + */ + private static Descriptor convertBQTableSchemaToProtoDescriptorImpl( + TableSchema BQTableSchema, + String scope, + HashMap, Descriptor> dependencyMap) + throws Descriptors.DescriptorValidationException { + List dependenciesList = new ArrayList(); + List fields = new ArrayList(); + int index = 1; + for (TableFieldSchema BQTableField : BQTableSchema.getFieldsList()) { + String currentScope = scope + "__" + BQTableField.getName(); + if (BQTableField.getType() == TableFieldSchema.Type.STRUCT) { + ImmutableList fieldList = + ImmutableList.copyOf(BQTableField.getFieldsList()); + if (dependencyMap.containsKey(fieldList)) { + Descriptor descriptor = dependencyMap.get(fieldList); + dependenciesList.add(descriptor.getFile()); + fields.add(convertBQTableFieldToProtoField(BQTableField, index++, descriptor.getName())); + } else { + Descriptor descriptor = + convertBQTableSchemaToProtoDescriptorImpl( + TableSchema.newBuilder().addAllFields(fieldList).build(), + currentScope, + dependencyMap); + dependenciesList.add(descriptor.getFile()); + dependencyMap.put(fieldList, descriptor); + fields.add(convertBQTableFieldToProtoField(BQTableField, index++, currentScope)); + } + } else { + fields.add(convertBQTableFieldToProtoField(BQTableField, index++, currentScope)); + } + } + FileDescriptor[] dependenciesArray = new FileDescriptor[dependenciesList.size()]; + dependenciesArray = dependenciesList.toArray(dependenciesArray); + DescriptorProto descriptorProto = + DescriptorProto.newBuilder().setName(scope).addAllField(fields).build(); + FileDescriptorProto fileDescriptorProto = + FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); + FileDescriptor fileDescriptor = + FileDescriptor.buildFrom(fileDescriptorProto, dependenciesArray); + Descriptor descriptor = fileDescriptor.findMessageTypeByName(scope); + return descriptor; + } + + /** + * Converts a BQTableField to ProtoField + * + * @param BQTableField BQ Field used to construct a FieldDescriptorProto + * @param index Index for protobuf fields. 
+ * @param scope used to name descriptors + */ + private static FieldDescriptorProto convertBQTableFieldToProtoField( + TableFieldSchema BQTableField, int index, String scope) { + TableFieldSchema.Mode mode = BQTableField.getMode(); + String fieldName = BQTableField.getName().toLowerCase(); + if (BQTableField.getType() == TableFieldSchema.Type.STRUCT) { + return FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setTypeName(scope) + .setLabel((FieldDescriptorProto.Label) BQTableSchemaModeMap.get(mode)) + .setNumber(index) + .build(); + } + return FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setType((FieldDescriptorProto.Type) BQTableSchemaTypeMap.get(BQTableField.getType())) + .setLabel((FieldDescriptorProto.Label) BQTableSchemaModeMap.get(mode)) + .setNumber(index) + .build(); + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriter.java new file mode 100644 index 0000000000..89c417461e --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriter.java @@ -0,0 +1,400 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.core.ApiFuture; +import com.google.api.gax.batching.BatchingSettings; +import com.google.api.gax.core.CredentialsProvider; +import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.common.base.Preconditions; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import javax.annotation.Nullable; +import org.json.JSONArray; +import org.json.JSONObject; + +/** + * A StreamWriter that can write JSON data (JSONObjects) to BigQuery tables. The JsonStreamWriter is + * built on top of a StreamWriter, and it simply converts all JSON data to protobuf messages then + * calls StreamWriter's append() method to write to BigQuery tables. It maintains all StreamWriter + * functions, but also provides an additional feature: schema update support, where if the BigQuery + * table schema is updated, users will be able to ingest data on the new schema after some time (in + * order of minutes). 
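// Illustrative usage sketch (not part of this patch): appending JSON rows with a JsonStreamWriter.
// The stream path and row contents are hypothetical, and tableSchema is assumed to be a TableSchema
// matching the destination table; newBuilder, build, append, and close are the methods defined in
// this class. append() returns an ApiFuture, so callers can block with get() or attach a callback.
String streamName = "projects/p/datasets/d/tables/t/streams/s"; // hypothetical stream path
try (JsonStreamWriter writer = JsonStreamWriter.newBuilder(streamName, tableSchema).build()) {
  JSONObject row = new JSONObject();
  row.put("name", "Alice");
  JSONArray rows = new JSONArray();
  rows.put(row);
  ApiFuture<AppendRowsResponse> response = writer.append(rows, /* allowUnknownFields= */ false);
  response.get(); // block until BigQuery acknowledges the append
}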
+ */ +public class JsonStreamWriter implements AutoCloseable { + private static String streamPatternString = + "projects/[^/]+/datasets/[^/]+/tables/[^/]+/streams/[^/]+"; + private static Pattern streamPattern = Pattern.compile(streamPatternString); + private static final Logger LOG = Logger.getLogger(JsonStreamWriter.class.getName()); + + private BigQueryWriteClient client; + private String streamName; + private StreamWriter streamWriter; + private Descriptor descriptor; + private TableSchema tableSchema; + + /** + * Constructs the JsonStreamWriter + * + * @param builder The Builder object for the JsonStreamWriter + */ + private JsonStreamWriter(Builder builder) + throws Descriptors.DescriptorValidationException, IllegalArgumentException, IOException, + InterruptedException { + Matcher matcher = streamPattern.matcher(builder.streamName); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid stream name: " + builder.streamName); + } + + this.streamName = builder.streamName; + this.client = builder.client; + this.descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(builder.tableSchema); + + StreamWriter.Builder streamWriterBuilder; + if (this.client == null) { + streamWriterBuilder = StreamWriter.newBuilder(builder.streamName); + } else { + streamWriterBuilder = StreamWriter.newBuilder(builder.streamName, builder.client); + } + setStreamWriterSettings( + streamWriterBuilder, + builder.channelProvider, + builder.credentialsProvider, + builder.batchingSettings, + builder.retrySettings, + builder.executorProvider, + builder.endpoint); + this.streamWriter = streamWriterBuilder.build(); + } + /** + * Writes a JSONArray that contains JSONObjects to the BigQuery table by first converting the JSON + * data to protobuf messages, then using StreamWriter's append() to write the data. If there is a + * schema update, the OnSchemaUpdateRunnable will be used to determine what actions to perform. + * + * @param jsonArr The JSON array that contains JSONObjects to be written + * @param allowUnknownFields if true, json data can have fields unknown to the BigQuery table. + * @return ApiFuture returns an AppendRowsResponse message wrapped in an + * ApiFuture + */ + public ApiFuture append(JSONArray jsonArr, boolean allowUnknownFields) { + return append(jsonArr, -1, allowUnknownFields); + } + + /** + * Writes a JSONArray that contains JSONObjects to the BigQuery table by first converting the JSON + * data to protobuf messages, then using StreamWriter's append() to write the data. If there is a + * schema update, the OnSchemaUpdateRunnable will be used to determine what actions to perform. + * + * @param jsonArr The JSON array that contains JSONObjects to be written + * @param offset Offset for deduplication + * @param allowUnknownFields if true, json data can have fields unknown to the BigQuery table. + * @return ApiFuture returns an AppendRowsResponse message wrapped in an + * ApiFuture + */ + public ApiFuture append( + JSONArray jsonArr, long offset, boolean allowUnknownFields) { + ProtoRows.Builder rowsBuilder = ProtoRows.newBuilder(); + // Any error in convertJsonToProtoMessage will throw an + // IllegalArgumentException/IllegalStateException/NullPointerException and will halt processing + // of JSON data. 
+ for (int i = 0; i < jsonArr.length(); i++) { + JSONObject json = jsonArr.getJSONObject(i); + Message protoMessage = + JsonToProtoMessage.convertJsonToProtoMessage(this.descriptor, json, allowUnknownFields); + rowsBuilder.addSerializedRows(protoMessage.toByteString()); + } + AppendRowsRequest.ProtoData.Builder data = AppendRowsRequest.ProtoData.newBuilder(); + // Need to make sure refreshAppendAndSetDescriptor finish first before this can run + synchronized (this) { + data.setWriterSchema(ProtoSchemaConverter.convert(this.descriptor)); + data.setRows(rowsBuilder.build()); + final ApiFuture appendResponseFuture = + this.streamWriter.append( + AppendRowsRequest.newBuilder() + .setProtoRows(data.build()) + .setOffset(Int64Value.of(offset)) + .build()); + return appendResponseFuture; + } + } + + /** + * Refreshes connection for a JsonStreamWriter by first flushing all remaining rows, then calling + * refreshAppend(), and finally setting the descriptor. All of these actions need to be performed + * atomically to avoid having synchronization issues with append(). Flushing all rows first is + * necessary since if there are rows remaining when the connection refreshes, it will send out the + * old writer schema instead of the new one. + */ + void refreshConnection() + throws IOException, InterruptedException, Descriptors.DescriptorValidationException { + synchronized (this) { + this.streamWriter.writeAllOutstanding(); + this.streamWriter.refreshAppend(); + this.descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(this.tableSchema); + } + } + + /** + * Gets streamName + * + * @return String + */ + public String getStreamName() { + return this.streamName; + } + + /** + * Gets current descriptor + * + * @return Descriptor + */ + public Descriptor getDescriptor() { + return this.descriptor; + } + + /** Sets all StreamWriter settings. */ + private void setStreamWriterSettings( + StreamWriter.Builder builder, + @Nullable TransportChannelProvider channelProvider, + @Nullable CredentialsProvider credentialsProvider, + @Nullable BatchingSettings batchingSettings, + @Nullable RetrySettings retrySettings, + @Nullable ExecutorProvider executorProvider, + @Nullable String endpoint) { + if (channelProvider != null) { + builder.setChannelProvider(channelProvider); + } + if (credentialsProvider != null) { + builder.setCredentialsProvider(credentialsProvider); + } + if (batchingSettings != null) { + builder.setBatchingSettings(batchingSettings); + } + if (retrySettings != null) { + builder.setRetrySettings(retrySettings); + } + if (executorProvider != null) { + builder.setExecutorProvider(executorProvider); + } + if (endpoint != null) { + builder.setEndpoint(endpoint); + } + JsonStreamWriterOnSchemaUpdateRunnable jsonStreamWriterOnSchemaUpdateRunnable = + new JsonStreamWriterOnSchemaUpdateRunnable(); + jsonStreamWriterOnSchemaUpdateRunnable.setJsonStreamWriter(this); + builder.setOnSchemaUpdateRunnable(jsonStreamWriterOnSchemaUpdateRunnable); + } + + /** + * Setter for table schema. Used for schema updates. + * + * @param tableSchema + */ + void setTableSchema(TableSchema tableSchema) { + this.tableSchema = tableSchema; + } + + /** + * newBuilder that constructs a JsonStreamWriter builder with BigQuery client being initialized by + * StreamWriter by default. 
+ * + * @param streamName name of the stream that must follow + * "projects/[^/]+/datasets/[^/]+/tables/[^/]+/streams/[^/]+" + * @param tableSchema The schema of the table when the stream was created, which is passed back + * through {@code WriteStream} + * @return Builder + */ + public static Builder newBuilder(String streamName, TableSchema tableSchema) { + Preconditions.checkNotNull(streamName, "StreamName is null."); + Preconditions.checkNotNull(tableSchema, "TableSchema is null."); + return new Builder(streamName, tableSchema, null); + } + + /** + * newBuilder that constructs a JsonStreamWriter builder. + * + * @param streamName name of the stream that must follow + * "projects/[^/]+/datasets/[^/]+/tables/[^/]+/streams/[^/]+" + * @param tableSchema The schema of the table when the stream was created, which is passed back + * through {@code WriteStream} + * @param client + * @return Builder + */ + public static Builder newBuilder( + String streamName, TableSchema tableSchema, BigQueryWriteClient client) { + Preconditions.checkNotNull(streamName, "StreamName is null."); + Preconditions.checkNotNull(tableSchema, "TableSchema is null."); + Preconditions.checkNotNull(client, "BigQuery client is null."); + return new Builder(streamName, tableSchema, client); + } + + /** Closes the underlying StreamWriter. */ + @Override + public void close() { + this.streamWriter.close(); + } + + private class JsonStreamWriterOnSchemaUpdateRunnable extends OnSchemaUpdateRunnable { + private JsonStreamWriter jsonStreamWriter; + /** + * Setter for the jsonStreamWriter + * + * @param jsonStreamWriter + */ + public void setJsonStreamWriter(JsonStreamWriter jsonStreamWriter) { + this.jsonStreamWriter = jsonStreamWriter; + } + + /** Getter for the jsonStreamWriter */ + public JsonStreamWriter getJsonStreamWriter() { + return this.jsonStreamWriter; + } + + @Override + public void run() { + this.getJsonStreamWriter().setTableSchema(this.getUpdatedSchema()); + try { + this.getJsonStreamWriter().refreshConnection(); + } catch (InterruptedException | IOException e) { + LOG.severe("StreamWriter failed to refresh upon schema update." + e); + return; + } catch (Descriptors.DescriptorValidationException e) { + LOG.severe( + "Schema update fail: updated schema could not be converted to a valid descriptor."); + return; + } + LOG.info("Successfully updated schema: " + this.getUpdatedSchema()); + } + } + + public static final class Builder { + private String streamName; + private BigQueryWriteClient client; + private TableSchema tableSchema; + + private TransportChannelProvider channelProvider; + private CredentialsProvider credentialsProvider; + private BatchingSettings batchingSettings; + private RetrySettings retrySettings; + private ExecutorProvider executorProvider; + private String endpoint; + + /** + * Constructor for JsonStreamWriter's Builder + * + * @param streamName name of the stream that must follow + * "projects/[^/]+/datasets/[^/]+/tables/[^/]+/streams/[^/]+" + * @param tableSchema schema used to convert Json to proto messages. + * @param client + */ + private Builder(String streamName, TableSchema tableSchema, BigQueryWriteClient client) { + this.streamName = streamName; + this.tableSchema = tableSchema; + this.client = client; + } + + /** + * Setter for the underlying StreamWriter's TransportChannelProvider. 
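// Illustrative sketch (not part of this patch): the optional Builder setters are forwarded to the
// underlying StreamWriter, and only non-null values are applied. The provider/settings variables
// and the endpoint string below are hypothetical placeholders.
JsonStreamWriter writer =
    JsonStreamWriter.newBuilder(streamName, tableSchema)
        .setChannelProvider(channelProvider)
        .setCredentialsProvider(credentialsProvider)
        .setBatchingSettings(batchingSettings)
        .setRetrySettings(retrySettings)
        .setEndpoint("bigquerystorage.googleapis.com:443") // assumed endpoint form
        .build();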
+ * + * @param channelProvider + * @return Builder + */ + public Builder setChannelProvider(TransportChannelProvider channelProvider) { + this.channelProvider = + Preconditions.checkNotNull(channelProvider, "ChannelProvider is null."); + return this; + } + + /** + * Setter for the underlying StreamWriter's CredentialsProvider. + * + * @param credentialsProvider + * @return Builder + */ + public Builder setCredentialsProvider(CredentialsProvider credentialsProvider) { + this.credentialsProvider = + Preconditions.checkNotNull(credentialsProvider, "CredentialsProvider is null."); + return this; + } + + /** + * Setter for the underlying StreamWriter's BatchingSettings. + * + * @param batchingSettings + * @return Builder + */ + public Builder setBatchingSettings(BatchingSettings batchingSettings) { + this.batchingSettings = + Preconditions.checkNotNull(batchingSettings, "BatchingSettings is null."); + return this; + } + + /** + * Setter for the underlying StreamWriter's RetrySettings. + * + * @param retrySettings + * @return Builder + */ + public Builder setRetrySettings(RetrySettings retrySettings) { + this.retrySettings = Preconditions.checkNotNull(retrySettings, "RetrySettings is null."); + return this; + } + + /** + * Setter for the underlying StreamWriter's ExecutorProvider. + * + * @param executorProvider + * @return Builder + */ + public Builder setExecutorProvider(ExecutorProvider executorProvider) { + this.executorProvider = + Preconditions.checkNotNull(executorProvider, "ExecutorProvider is null."); + return this; + } + + /** + * Setter for the underlying StreamWriter's Endpoint. + * + * @param endpoint + * @return Builder + */ + public Builder setEndpoint(String endpoint) { + this.endpoint = Preconditions.checkNotNull(endpoint, "Endpoint is null."); + return this; + } + + /** + * Builds JsonStreamWriter + * + * @return JsonStreamWriter + */ + public JsonStreamWriter build() + throws Descriptors.DescriptorValidationException, IllegalArgumentException, IOException, + InterruptedException { + return new JsonStreamWriter(this); + } + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessage.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessage.java new file mode 100644 index 0000000000..8182e21176 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessage.java @@ -0,0 +1,323 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import com.google.protobuf.UninitializedMessageException; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +/** + * Converts Json data to protocol buffer messages given the protocol buffer descriptor. The protobuf + * descriptor must have all fields lowercased. + */ +public class JsonToProtoMessage { + private static ImmutableMap FieldTypeToDebugMessage = + new ImmutableMap.Builder() + .put(FieldDescriptor.Type.BOOL, "boolean") + .put(FieldDescriptor.Type.BYTES, "string") + .put(FieldDescriptor.Type.INT32, "int32") + .put(FieldDescriptor.Type.DOUBLE, "double") + .put(FieldDescriptor.Type.INT64, "int64") + .put(FieldDescriptor.Type.STRING, "string") + .put(FieldDescriptor.Type.MESSAGE, "object") + .build(); + + /** + * Converts Json data to protocol buffer messages given the protocol buffer descriptor. + * + * @param protoSchema + * @param json + * @param allowUnknownFields Ignores unknown JSON fields. + * @throws IllegalArgumentException when JSON data is not compatible with proto descriptor. + */ + public static DynamicMessage convertJsonToProtoMessage( + Descriptor protoSchema, JSONObject json, boolean allowUnknownFields) + throws IllegalArgumentException { + Preconditions.checkNotNull(json, "JSONObject is null."); + Preconditions.checkNotNull(protoSchema, "Protobuf descriptor is null."); + Preconditions.checkState(json.length() != 0, "JSONObject is empty."); + + return convertJsonToProtoMessageImpl( + protoSchema, json, "root", /*topLevel=*/ true, allowUnknownFields); + } + + /** + * Converts Json data to protocol buffer messages given the protocol buffer descriptor. + * + * @param protoSchema + * @param json + * @param jsonScope Debugging purposes + * @param allowUnknownFields Ignores unknown JSON fields. + * @param topLevel checks if root level has any matching fields. + * @throws IllegalArgumentException when JSON data is not compatible with proto descriptor. + */ + private static DynamicMessage convertJsonToProtoMessageImpl( + Descriptor protoSchema, + JSONObject json, + String jsonScope, + boolean topLevel, + boolean allowUnknownFields) + throws IllegalArgumentException { + + DynamicMessage.Builder protoMsg = DynamicMessage.newBuilder(protoSchema); + String[] jsonNames = JSONObject.getNames(json); + if (jsonNames == null) { + return protoMsg.build(); + } + int matchedFields = 0; + for (int i = 0; i < jsonNames.length; i++) { + String jsonName = jsonNames[i]; + // We want lowercase here to support case-insensitive data writes. + // The protobuf descriptor that is used is assumed to have all lowercased fields + String jsonLowercaseName = jsonName.toLowerCase(); + String currentScope = jsonScope + "." + jsonName; + FieldDescriptor field = protoSchema.findFieldByName(jsonLowercaseName); + if (field == null) { + if (!allowUnknownFields) { + throw new IllegalArgumentException( + String.format( + "JSONObject has fields unknown to BigQuery: %s. 
Set allowUnknownFields to True to allow unknown fields.", + currentScope)); + } else { + continue; + } + } + matchedFields++; + if (!field.isRepeated()) { + fillField(protoMsg, field, json, jsonName, currentScope, allowUnknownFields); + } else { + fillRepeatedField(protoMsg, field, json, jsonName, currentScope, allowUnknownFields); + } + } + + if (matchedFields == 0 && topLevel) { + throw new IllegalArgumentException( + "There are no matching fields found for the JSONObject and the protocol buffer descriptor."); + } + DynamicMessage msg; + try { + msg = protoMsg.build(); + } catch (UninitializedMessageException e) { + String errorMsg = e.getMessage(); + int idxOfColon = errorMsg.indexOf(":"); + String missingFieldName = errorMsg.substring(idxOfColon + 2); + throw new IllegalArgumentException( + String.format( + "JSONObject does not have the required field %s.%s.", jsonScope, missingFieldName)); + } + if (topLevel && msg.getSerializedSize() == 0) { + throw new IllegalArgumentException("The created protobuf message is empty."); + } + return msg; + } + + /** + * Fills a non-repetaed protoField with the json data. + * + * @param protoMsg The protocol buffer message being constructed + * @param fieldDescriptor + * @param json + * @param exactJsonKeyName Exact key name in JSONObject instead of lowercased version + * @param currentScope Debugging purposes + * @param allowUnknownFields Ignores unknown JSON fields. + * @throws IllegalArgumentException when JSON data is not compatible with proto descriptor. + */ + private static void fillField( + DynamicMessage.Builder protoMsg, + FieldDescriptor fieldDescriptor, + JSONObject json, + String exactJsonKeyName, + String currentScope, + boolean allowUnknownFields) + throws IllegalArgumentException { + + java.lang.Object val = json.get(exactJsonKeyName); + switch (fieldDescriptor.getType()) { + case BOOL: + if (val instanceof Boolean) { + protoMsg.setField(fieldDescriptor, (Boolean) val); + return; + } + break; + case BYTES: + if (val instanceof String) { + protoMsg.setField(fieldDescriptor, ((String) val).getBytes()); + return; + } + break; + case INT64: + if (val instanceof Integer) { + protoMsg.setField(fieldDescriptor, new Long((Integer) val)); + return; + } else if (val instanceof Long) { + protoMsg.setField(fieldDescriptor, (Long) val); + return; + } + break; + case INT32: + if (val instanceof Integer) { + protoMsg.setField(fieldDescriptor, (Integer) val); + return; + } + break; + case STRING: + if (val instanceof String) { + protoMsg.setField(fieldDescriptor, (String) val); + return; + } + break; + case DOUBLE: + if (val instanceof Double) { + protoMsg.setField(fieldDescriptor, (Double) val); + return; + } else if (val instanceof Float) { + protoMsg.setField(fieldDescriptor, new Double((Float) val)); + return; + } + break; + case MESSAGE: + if (val instanceof JSONObject) { + Message.Builder message = protoMsg.newBuilderForField(fieldDescriptor); + protoMsg.setField( + fieldDescriptor, + convertJsonToProtoMessageImpl( + fieldDescriptor.getMessageType(), + json.getJSONObject(exactJsonKeyName), + currentScope, + /*topLevel =*/ false, + allowUnknownFields)); + return; + } + break; + } + throw new IllegalArgumentException( + String.format( + "JSONObject does not have a %s field at %s.", + FieldTypeToDebugMessage.get(fieldDescriptor.getType()), currentScope)); + } + + /** + * Fills a repeated protoField with the json data. 
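// Illustrative sketch (not part of this patch): a REPEATED BigQuery field is expected as a JSON
// array, and every element must match the field's scalar type. Here "descriptor" is assumed to come
// from BQTableSchemaToProtoDescriptor for a schema with a REPEATED INT64 field named "scores".
JSONArray scores = new JSONArray();
scores.put(85).put(92).put(78);
JSONObject row = new JSONObject();
row.put("scores", scores);
DynamicMessage msg =
    JsonToProtoMessage.convertJsonToProtoMessage(descriptor, row, /* allowUnknownFields= */ false);
// Passing a non-array value for "scores" would raise an IllegalArgumentException at root.scores.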
+ * + * @param protoMsg The protocol buffer message being constructed + * @param fieldDescriptor + * @param json If root level has no matching fields, throws exception. + * @param exactJsonKeyName Exact key name in JSONObject instead of lowercased version + * @param currentScope Debugging purposes + * @param allowUnknownFields Ignores unknown JSON fields. + * @throws IllegalArgumentException when JSON data is not compatible with proto descriptor. + */ + private static void fillRepeatedField( + DynamicMessage.Builder protoMsg, + FieldDescriptor fieldDescriptor, + JSONObject json, + String exactJsonKeyName, + String currentScope, + boolean allowUnknownFields) + throws IllegalArgumentException { + + JSONArray jsonArray; + try { + jsonArray = json.getJSONArray(exactJsonKeyName); + } catch (JSONException e) { + throw new IllegalArgumentException( + "JSONObject does not have a array field at " + currentScope + "."); + } + java.lang.Object val; + int index; + boolean fail = false; + for (int i = 0; i < jsonArray.length(); i++) { + val = jsonArray.get(i); + index = i; + switch (fieldDescriptor.getType()) { + case BOOL: + if (val instanceof Boolean) { + protoMsg.addRepeatedField(fieldDescriptor, (Boolean) val); + } else { + fail = true; + } + break; + case BYTES: + if (val instanceof String) { + protoMsg.addRepeatedField(fieldDescriptor, ((String) val).getBytes()); + } else { + fail = true; + } + break; + case INT64: + if (val instanceof Integer) { + protoMsg.addRepeatedField(fieldDescriptor, new Long((Integer) val)); + } else if (val instanceof Long) { + protoMsg.addRepeatedField(fieldDescriptor, (Long) val); + } else { + fail = true; + } + break; + case INT32: + if (val instanceof Integer) { + protoMsg.addRepeatedField(fieldDescriptor, (Integer) val); + } else { + fail = true; + } + break; + case STRING: + if (val instanceof String) { + protoMsg.addRepeatedField(fieldDescriptor, (String) val); + } else { + fail = true; + } + break; + case DOUBLE: + if (val instanceof Double) { + protoMsg.addRepeatedField(fieldDescriptor, (Double) val); + } else if (val instanceof Float) { + protoMsg.addRepeatedField(fieldDescriptor, new Double((float) val)); + } else { + fail = true; + } + break; + case MESSAGE: + if (val instanceof JSONObject) { + Message.Builder message = protoMsg.newBuilderForField(fieldDescriptor); + protoMsg.addRepeatedField( + fieldDescriptor, + convertJsonToProtoMessageImpl( + fieldDescriptor.getMessageType(), + jsonArray.getJSONObject(i), + currentScope, + /*topLevel =*/ false, + allowUnknownFields)); + } else { + fail = true; + } + break; + } + if (fail) { + throw new IllegalArgumentException( + String.format( + "JSONObject does not have a %s field at %s[%d].", + FieldTypeToDebugMessage.get(fieldDescriptor.getType()), currentScope, index)); + } + } + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/OnSchemaUpdateRunnable.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/OnSchemaUpdateRunnable.java new file mode 100644 index 0000000000..17c961cab7 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/OnSchemaUpdateRunnable.java @@ -0,0 +1,54 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +/** + * A abstract class that implements the Runnable interface and provides access to the current + * StreamWriter and updatedSchema. This runnable will only be called when a updated schema has been + * passed back through the AppendRowsResponse. Users should only implement the run() function. + */ +public abstract class OnSchemaUpdateRunnable implements Runnable { + private StreamWriter streamWriter; + private TableSchema updatedSchema; + + /** + * Setter for the updatedSchema + * + * @param updatedSchema + */ + void setUpdatedSchema(TableSchema updatedSchema) { + this.updatedSchema = updatedSchema; + } + + /** + * Setter for the streamWriter + * + * @param streamWriter + */ + void setStreamWriter(StreamWriter streamWriter) { + this.streamWriter = streamWriter; + } + + /** Getter for the updatedSchema */ + TableSchema getUpdatedSchema() { + return this.updatedSchema; + } + + /** Getter for the streamWriter */ + StreamWriter getStreamWriter() { + return this.streamWriter; + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverter.java new file mode 100644 index 0000000000..f2112a2be0 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverter.java @@ -0,0 +1,118 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.rpc.InvalidArgumentException; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.EnumDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import io.grpc.Status; +import java.util.HashSet; +import java.util.Set; + +// A Converter class that turns a native protobuf::DescriptorProto to a self contained +// protobuf::DescriptorProto +// that can be reconstructed by the backend. 
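// Illustrative sketch (not part of this patch): flattening a descriptor whose fields reference other
// message or enum types into a single self-contained DescriptorProto. "OuterMessage" stands for any
// hypothetical generated proto class; after conversion, each referenced type becomes a nested type
// whose name is the original full name with '.' replaced by '_'.
ProtoSchema schema = ProtoSchemaConverter.convert(OuterMessage.getDescriptor());
DescriptorProto selfContained = schema.getProtoDescriptor();
// selfContained.getNestedTypeList() carries all referenced message/enum definitions, so the backend
// can rebuild the schema without external .proto dependencies.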
+public class ProtoSchemaConverter { + private static String getNameFromFullName(String fullName) { + return fullName.replace('.', '_'); + } + + private static ProtoSchema convertInternal( + Descriptor input, + Set visitedTypes, + Set enumTypes, + Set structTypes, + DescriptorProto.Builder rootProtoSchema) { + DescriptorProto.Builder resultProto = DescriptorProto.newBuilder(); + if (rootProtoSchema == null) { + rootProtoSchema = resultProto; + } + String protoFullName = input.getFullName(); + String protoName = getNameFromFullName(protoFullName); + resultProto.setName(protoName); + Set localEnumTypes = new HashSet(); + visitedTypes.add(input.getFullName()); + for (int i = 0; i < input.getFields().size(); i++) { + FieldDescriptor inputField = input.getFields().get(i); + FieldDescriptorProto.Builder resultField = inputField.toProto().toBuilder(); + if (inputField.getType() == FieldDescriptor.Type.GROUP + || inputField.getType() == FieldDescriptor.Type.MESSAGE) { + String msgFullName = inputField.getMessageType().getFullName(); + String msgName = getNameFromFullName(msgFullName); + if (structTypes.contains(msgFullName)) { + resultField.setTypeName(msgName); + } else { + if (visitedTypes.contains(msgFullName)) { + throw new InvalidArgumentException( + "Recursive type is not supported:" + inputField.getMessageType().getFullName(), + null, + GrpcStatusCode.of(Status.Code.INVALID_ARGUMENT), + false); + } + visitedTypes.add(msgFullName); + rootProtoSchema.addNestedType( + convertInternal( + inputField.getMessageType(), + visitedTypes, + enumTypes, + structTypes, + rootProtoSchema) + .getProtoDescriptor()); + visitedTypes.remove(msgFullName); + resultField.setTypeName( + rootProtoSchema.getNestedType(rootProtoSchema.getNestedTypeCount() - 1).getName()); + } + } + + if (inputField.getType() == FieldDescriptor.Type.ENUM) { + // For enums, in order to avoid value conflict, we will always define + // a enclosing struct called enum_full_name_E that includes the actual + // enum. + String enumFullName = inputField.getEnumType().getFullName(); + String enclosingTypeName = getNameFromFullName(enumFullName) + "_E"; + String enumName = inputField.getEnumType().getName(); + String actualEnumFullName = enclosingTypeName + "." 
+ enumName; + if (enumTypes.contains(enumFullName)) { + resultField.setTypeName(actualEnumFullName); + } else { + EnumDescriptorProto enumType = inputField.getEnumType().toProto(); + resultProto.addNestedType( + DescriptorProto.newBuilder() + .setName(enclosingTypeName) + .addEnumType(enumType.toBuilder().setName(enumName)) + .build()); + resultField.setTypeName(actualEnumFullName); + enumTypes.add(enumFullName); + } + } + resultProto.addField(resultField); + } + structTypes.add(protoFullName); + + return ProtoSchema.newBuilder().setProtoDescriptor(resultProto.build()).build(); + } + + public static ProtoSchema convert(Descriptor descriptor) { + Set visitedTypes = new HashSet(); + Set enumTypes = new HashSet(); + Set structTypes = new HashSet(); + return convertInternal(descriptor, visitedTypes, enumTypes, structTypes, null); + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibility.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibility.java new file mode 100644 index 0000000000..238bbbcf34 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibility.java @@ -0,0 +1,543 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.LegacySQLTypeName; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.testing.RemoteBigQueryHelper; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.protobuf.Descriptors; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * A class that checks the schema compatibility between Proto schema in proto descriptor and + * Bigquery table schema. If this check is passed, then user can write to BigQuery table using the + * user schema, otherwise the write will fail. + * + *
<p>
The implementation as of now is not complete, which measn, if this check passed, there is + * still a possbility of writing will fail. + */ +public class SchemaCompatibility { + private BigQuery bigquery; + private static SchemaCompatibility compat; + private static String tablePatternString = "projects/([^/]+)/datasets/([^/]+)/tables/([^/]+)"; + private static Pattern tablePattern = Pattern.compile(tablePatternString); + private static final int NestingLimit = 15; + // private static Set SupportedTypesHashSet = + + private static Set SupportedTypes = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Descriptors.FieldDescriptor.Type.INT32, + Descriptors.FieldDescriptor.Type.INT64, + Descriptors.FieldDescriptor.Type.UINT32, + Descriptors.FieldDescriptor.Type.UINT64, + Descriptors.FieldDescriptor.Type.FIXED32, + Descriptors.FieldDescriptor.Type.FIXED64, + Descriptors.FieldDescriptor.Type.SFIXED32, + Descriptors.FieldDescriptor.Type.SFIXED64, + Descriptors.FieldDescriptor.Type.FLOAT, + Descriptors.FieldDescriptor.Type.DOUBLE, + Descriptors.FieldDescriptor.Type.BOOL, + Descriptors.FieldDescriptor.Type.BYTES, + Descriptors.FieldDescriptor.Type.STRING, + Descriptors.FieldDescriptor.Type.MESSAGE, + Descriptors.FieldDescriptor.Type.GROUP, + Descriptors.FieldDescriptor.Type.ENUM))); + + private SchemaCompatibility(BigQuery bigquery) { + // TODO: Add functionality that allows SchemaCompatibility to build schemas. + this.bigquery = bigquery; + } + + /** + * Gets a singleton {code SchemaCompatibility} object. + * + * @return + */ + public static SchemaCompatibility getInstance() { + if (compat == null) { + RemoteBigQueryHelper bigqueryHelper = RemoteBigQueryHelper.create(); + compat = new SchemaCompatibility(bigqueryHelper.getOptions().getService()); + } + return compat; + } + + /** + * Gets a {code SchemaCompatibility} object with custom BigQuery stub. 
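// Illustrative sketch (not part of this patch): verifying a generated proto against an existing
// table before writing. The table path and FooMessage class are hypothetical; check() (defined
// below) fetches the table schema through the BigQuery client and throws IllegalArgumentException
// on the first incompatibility it finds.
SchemaCompatibility compatibility = SchemaCompatibility.getInstance();
compatibility.check(
    "projects/my-project/datasets/my_dataset/tables/my_table",
    FooMessage.getDescriptor(),
    /* allowUnknownFields= */ false);
// Returning normally means the proto schema was accepted (though, as noted above, the check is not
// yet exhaustive).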
+ * + * @param bigquery + * @return + */ + @VisibleForTesting + public static SchemaCompatibility getInstance(BigQuery bigquery) { + Preconditions.checkNotNull(bigquery, "BigQuery is null."); + return new SchemaCompatibility(bigquery); + } + + private TableId getTableId(String tableName) { + Matcher matcher = tablePattern.matcher(tableName); + if (!matcher.matches() || matcher.groupCount() != 3) { + throw new IllegalArgumentException("Invalid table name: " + tableName); + } + return TableId.of(matcher.group(1), matcher.group(2), matcher.group(3)); + } + + /** + * @param field + * @return True if fieldtype is supported by BQ Schema + */ + public static boolean isSupportedType(Descriptors.FieldDescriptor field) { + Preconditions.checkNotNull(field, "Field is null."); + Descriptors.FieldDescriptor.Type fieldType = field.getType(); + if (!SupportedTypes.contains(fieldType)) { + return false; + } + return true; + } + + private static boolean isCompatibleWithBQBool(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.BOOL + || field == Descriptors.FieldDescriptor.Type.INT32 + || field == Descriptors.FieldDescriptor.Type.INT64 + || field == Descriptors.FieldDescriptor.Type.UINT32 + || field == Descriptors.FieldDescriptor.Type.UINT64 + || field == Descriptors.FieldDescriptor.Type.FIXED32 + || field == Descriptors.FieldDescriptor.Type.FIXED64 + || field == Descriptors.FieldDescriptor.Type.SFIXED32 + || field == Descriptors.FieldDescriptor.Type.SFIXED64) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQBytes(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.BYTES) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQDate(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.INT32 + || field == Descriptors.FieldDescriptor.Type.INT64 + || field == Descriptors.FieldDescriptor.Type.SFIXED32 + || field == Descriptors.FieldDescriptor.Type.SFIXED64) { + + return true; + } + return false; + } + + private static boolean isCompatibleWithBQDatetime(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.STRING + || field == Descriptors.FieldDescriptor.Type.INT64) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQFloat(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.FLOAT) { + return true; + } + if (field == Descriptors.FieldDescriptor.Type.DOUBLE) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQGeography(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.STRING) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQInteger(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.INT64 + || field == Descriptors.FieldDescriptor.Type.SFIXED64 + || field == Descriptors.FieldDescriptor.Type.INT32 + || field == Descriptors.FieldDescriptor.Type.UINT32 + || field == Descriptors.FieldDescriptor.Type.FIXED32 + || field == Descriptors.FieldDescriptor.Type.SFIXED32 + || field == Descriptors.FieldDescriptor.Type.ENUM) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQNumeric(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.INT32 + || field == Descriptors.FieldDescriptor.Type.INT64 + || field == 
Descriptors.FieldDescriptor.Type.UINT32 + || field == Descriptors.FieldDescriptor.Type.UINT64 + || field == Descriptors.FieldDescriptor.Type.FIXED32 + || field == Descriptors.FieldDescriptor.Type.FIXED64 + || field == Descriptors.FieldDescriptor.Type.SFIXED32 + || field == Descriptors.FieldDescriptor.Type.SFIXED64 + || field == Descriptors.FieldDescriptor.Type.STRING + || field == Descriptors.FieldDescriptor.Type.BYTES + || field == Descriptors.FieldDescriptor.Type.FLOAT + || field == Descriptors.FieldDescriptor.Type.DOUBLE) { + return true; + } + + return false; + } + + private static boolean isCompatibleWithBQRecord(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.MESSAGE + || field == Descriptors.FieldDescriptor.Type.GROUP) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQString(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.STRING + || field == Descriptors.FieldDescriptor.Type.ENUM) { + return true; + } + return false; + } + + private static boolean isCompatibleWithBQTime(Descriptors.FieldDescriptor.Type field) { + if (field == Descriptors.FieldDescriptor.Type.INT64 + || field == Descriptors.FieldDescriptor.Type.STRING) { + + return true; + } + return false; + } + + private static boolean isCompatibleWithBQTimestamp(Descriptors.FieldDescriptor.Type field) { + if (isCompatibleWithBQInteger(field)) { + return true; + } + return false; + } + + /** + * Checks if proto field option is compatible with BQ field mode. + * + * @param protoField + * @param BQField + * @param protoScope Debugging purposes to show error if messages are nested. + * @param BQScope Debugging purposes to show error if messages are nested. + * @throws IllegalArgumentException if proto field type is incompatible with BQ field type. + */ + private void protoFieldModeIsCompatibleWithBQFieldMode( + Descriptors.FieldDescriptor protoField, Field BQField, String protoScope, String BQScope) + throws IllegalArgumentException { + if (BQField.getMode() == null) { + throw new IllegalArgumentException( + "Big query schema contains invalid field option for " + BQScope + "."); + } + switch (BQField.getMode()) { + case REPEATED: + if (!protoField.isRepeated()) { + throw new IllegalArgumentException( + "Given proto field " + + protoScope + + " is not repeated but Big Query field " + + BQScope + + " is."); + } + break; + case REQUIRED: + if (!protoField.isRequired()) { + throw new IllegalArgumentException( + "Given proto field " + + protoScope + + " is not required but Big Query field " + + BQScope + + " is."); + } + break; + case NULLABLE: + if (protoField.isRepeated()) { + throw new IllegalArgumentException( + "Given proto field " + + protoScope + + " is repeated but Big Query field " + + BQScope + + " is optional."); + } + break; + } + } + /** + * Checks if proto field type is compatible with BQ field type. + * + * @param protoField + * @param BQField + * @param allowUnknownFields + * @param protoScope Debugging purposes to show error if messages are nested. + * @param BQScope Debugging purposes to show error if messages are nested. + * @param allMessageTypes Keeps track of all current protos to avoid recursively nested protos. + * @param rootProtoName Debugging purposes for nested level > 15. + * @throws IllegalArgumentException if proto field type is incompatible with BQ field type. 
+ */ + private void protoFieldTypeIsCompatibleWithBQFieldType( + Descriptors.FieldDescriptor protoField, + Field BQField, + boolean allowUnknownFields, + String protoScope, + String BQScope, + HashSet allMessageTypes, + String rootProtoName) + throws IllegalArgumentException { + + LegacySQLTypeName BQType = BQField.getType(); + Descriptors.FieldDescriptor.Type protoType = protoField.getType(); + boolean match = false; + switch (BQType.toString()) { + case "BOOLEAN": + match = isCompatibleWithBQBool(protoType); + break; + case "BYTES": + match = isCompatibleWithBQBytes(protoType); + break; + case "DATE": + match = isCompatibleWithBQDate(protoType); + break; + case "DATETIME": + match = isCompatibleWithBQDatetime(protoType); + break; + case "FLOAT": + match = isCompatibleWithBQFloat(protoType); + break; + case "GEOGRAPHY": + match = isCompatibleWithBQGeography(protoType); + break; + case "INTEGER": + match = isCompatibleWithBQInteger(protoType); + break; + case "NUMERIC": + match = isCompatibleWithBQNumeric(protoType); + break; + case "RECORD": + if (allMessageTypes.size() > NestingLimit) { + throw new IllegalArgumentException( + "Proto schema " + + rootProtoName + + " is not supported: contains nested messages of more than 15 levels."); + } + match = isCompatibleWithBQRecord(protoType); + if (!match) { + break; + } + Descriptors.Descriptor message = protoField.getMessageType(); + if (allMessageTypes.contains(message)) { + throw new IllegalArgumentException( + "Proto schema " + protoScope + " is not supported: is a recursively nested message."); + } + allMessageTypes.add(message); + isProtoCompatibleWithBQ( + protoField.getMessageType(), + Schema.of(BQField.getSubFields()), + allowUnknownFields, + protoScope, + BQScope, + false, + allMessageTypes, + rootProtoName); + allMessageTypes.remove(message); + break; + case "STRING": + match = isCompatibleWithBQString(protoType); + break; + case "TIME": + match = isCompatibleWithBQTime(protoType); + break; + case "TIMESTAMP": + match = isCompatibleWithBQTimestamp(protoType); + break; + } + if (!match) { + throw new IllegalArgumentException( + "The proto field " + + protoScope + + " does not have a matching type with the big query field " + + BQScope + + "."); + } + } + + /** + * Checks if proto schema is compatible with BQ schema. + * + * @param protoSchema + * @param BQSchema + * @param allowUnknownFields + * @param protoScope Debugging purposes to show error if messages are nested. + * @param BQScope Debugging purposes to show error if messages are nested. + * @param topLevel True if this is the root level of proto (in terms of nested messages) + * @param allMessageTypes Keeps track of all current protos to avoid recursively nested protos. + * @param rootProtoName Debugging purposes for nested level > 15. + * @throws IllegalArgumentException if proto field type is incompatible with BQ field type. 
+ */ + private void isProtoCompatibleWithBQ( + Descriptors.Descriptor protoSchema, + Schema BQSchema, + boolean allowUnknownFields, + String protoScope, + String BQScope, + boolean topLevel, + HashSet allMessageTypes, + String rootProtoName) + throws IllegalArgumentException { + + int matchedFields = 0; + HashMap protoFieldMap = new HashMap<>(); + List protoFields = protoSchema.getFields(); + List BQFields = BQSchema.getFields(); + + if (protoFields.size() > BQFields.size()) { + if (!allowUnknownFields) { + throw new IllegalArgumentException( + "Proto schema " + + protoScope + + " has " + + protoFields.size() + + " fields, while BQ schema " + + BQScope + + " has " + + BQFields.size() + + " fields."); + } + } + // Use hashmap to map from lowercased name to appropriate field to account for casing difference + for (Descriptors.FieldDescriptor field : protoFields) { + protoFieldMap.put(field.getName().toLowerCase(), field); + } + + for (Field BQField : BQFields) { + String fieldName = BQField.getName().toLowerCase(); + Descriptors.FieldDescriptor protoField = null; + if (protoFieldMap.containsKey(fieldName)) { + protoField = protoFieldMap.get(fieldName); + } + + String currentBQScope = BQScope + "." + BQField.getName(); + if (protoField == null && BQField.getMode() == Field.Mode.REQUIRED) { + throw new IllegalArgumentException( + "The required Big Query field " + + currentBQScope + + " is missing in the proto schema " + + protoScope + + "."); + } + if (protoField == null) { + continue; + } + String currentProtoScope = protoScope + "." + protoField.getName(); + if (!isSupportedType(protoField)) { + throw new IllegalArgumentException( + "Proto schema " + + currentProtoScope + + " is not supported: contains " + + protoField.getType() + + " field type."); + } + if (protoField.isMapField()) { + throw new IllegalArgumentException( + "Proto schema " + currentProtoScope + " is not supported: is a map field."); + } + protoFieldModeIsCompatibleWithBQFieldMode( + protoField, BQField, currentProtoScope, currentBQScope); + protoFieldTypeIsCompatibleWithBQFieldType( + protoField, + BQField, + allowUnknownFields, + currentProtoScope, + currentBQScope, + allMessageTypes, + rootProtoName); + matchedFields++; + } + + if (matchedFields == 0 && topLevel) { + throw new IllegalArgumentException( + "There is no matching fields found for the proto schema " + + protoScope + + " and the BQ table schema " + + BQScope + + "."); + } + } + + /** + * Checks if proto schema is compatible with BQ schema after retrieving BQ schema by BQTableName. + * + * @param BQTableName Must include project_id, dataset_id, and table_id in the form that matches + * the regex "projects/([^/]+)/datasets/([^/]+)/tables/([^/]+)" + * @param protoSchema + * @param allowUnknownFields Flag indicating proto can have unknown fields. + * @throws IllegalArgumentException if proto field type is incompatible with BQ field type. 
+ */ + public void check( + String BQTableName, Descriptors.Descriptor protoSchema, boolean allowUnknownFields) + throws IllegalArgumentException { + Preconditions.checkNotNull(BQTableName, "TableName is null."); + Preconditions.checkNotNull(protoSchema, "Protobuf descriptor is null."); + + TableId tableId = getTableId(BQTableName); + Table table = bigquery.getTable(tableId); + Schema BQSchema = table.getDefinition().getSchema(); + String protoSchemaName = protoSchema.getName(); + HashSet allMessageTypes = new HashSet<>(); + allMessageTypes.add(protoSchema); + isProtoCompatibleWithBQ( + protoSchema, + BQSchema, + allowUnknownFields, + protoSchemaName, + tableId.getTable(), + true, + allMessageTypes, + protoSchemaName); + } + + /** + * Checks if proto schema is compatible with BQ schema after retrieving BQ schema by BQTableName. + * Assumes allowUnknownFields is false. + * + * @param BQTableName Must include project_id, dataset_id, and table_id in the form that matches + * the regex "projects/([^/]+)/datasets/([^/]+)/tables/([^/]+)" + * @param protoSchema + * @throws IllegalArgumentException if proto field type is incompatible with BQ field type. + */ + public void check(String BQTableName, Descriptors.Descriptor protoSchema) + throws IllegalArgumentException { + + check(BQTableName, protoSchema, false); + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriter.java new file mode 100644 index 0000000000..b7b7fbb035 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriter.java @@ -0,0 +1,1018 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.core.ApiFuture; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.batching.BatchingSettings; +import com.google.api.gax.batching.FlowControlSettings; +import com.google.api.gax.batching.FlowController; +import com.google.api.gax.core.BackgroundResource; +import com.google.api.gax.core.BackgroundResourceAggregation; +import com.google.api.gax.core.CredentialsProvider; +import com.google.api.gax.core.ExecutorAsBackgroundResource; +import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.core.InstantiatingExecutorProvider; +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.AbortedException; +import com.google.api.gax.rpc.BidiStreamingCallable; +import com.google.api.gax.rpc.ClientStream; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.StreamController; +import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.storage.v1beta2.StorageProto.*; +import com.google.common.base.Preconditions; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.threeten.bp.Duration; + +/** + * A BigQuery Stream Writer that can be used to write data into BigQuery Table. + * + *

This is to be used for managed streaming writes when you are working with PENDING streams or + * want to explicitly manage offsets. In the most common case of writing to a COMMITTED stream + * without offsets, please use the simpler writer {@code DirectWriter}. + * + *

A {@link StreamWriter} provides built-in capabilities to: handle batching of messages; + * control memory utilization (through flow control); and automatically re-establish the connection and + * clean up requests (only the first request on the stream carries the writer schema). + * + *

With customizable options that control: + * + *

    + *
  • Message batching: such as number of messages or max batch byte size, and batching deadline + *
  • Inflight message control: such as number of messages or max batch byte size + *
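+ *
+ * <p>For example, the batching options above could be tuned through the {@code Builder} (a
+ * sketch only; the stream name is a placeholder and the thresholds are arbitrary):
+ *
+ * <pre>{@code
+ * StreamWriter writer =
+ *     StreamWriter.newBuilder("projects/p/datasets/d/tables/t/streams/s")
+ *         .setBatchingSettings(
+ *             StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS
+ *                 .toBuilder()
+ *                 .setElementCountThreshold(50L)
+ *                 .setDelayThreshold(Duration.ofMillis(20))
+ *                 .build())
+ *         .build();
+ * }</pre>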
+ * + *

{@link StreamWriter} will use the credentials set on the channel, which uses application + * default credentials through {@link GoogleCredentials#getApplicationDefault} by default. + */ +public class StreamWriter implements AutoCloseable { + private static final Logger LOG = Logger.getLogger(StreamWriter.class.getName()); + + private static String streamPatternString = + "(projects/[^/]+/datasets/[^/]+/tables/[^/]+)/streams/[^/]+"; + + private static Pattern streamPattern = Pattern.compile(streamPatternString); + + private final String streamName; + private final String tableName; + + private final BatchingSettings batchingSettings; + private final RetrySettings retrySettings; + private BigQueryWriteSettings stubSettings; + + private final Lock messagesBatchLock; + private final Lock appendAndRefreshAppendLock; + private final MessagesBatch messagesBatch; + + // Indicates if a stream has some non recoverable exception happened. + private final Lock exceptionLock; + private Throwable streamException; + + private BackgroundResource backgroundResources; + private List backgroundResourceList; + + private BigQueryWriteClient stub; + BidiStreamingCallable bidiStreamingCallable; + ClientStream clientStream; + private final AppendResponseObserver responseObserver; + + private final ScheduledExecutorService executor; + + private final AtomicBoolean shutdown; + private final Waiter messagesWaiter; + private final AtomicBoolean activeAlarm; + private ScheduledFuture currentAlarmFuture; + + private Integer currentRetries = 0; + + // Used for schema updates + private OnSchemaUpdateRunnable onSchemaUpdateRunnable; + + /** The maximum size of one request. Defined by the API. */ + public static long getApiMaxRequestBytes() { + return 10L * 1000L * 1000L; // 10 megabytes (https://en.wikipedia.org/wiki/Megabyte) + } + + /** The maximum size of in flight requests. Defined by the API. 
*/ + public static long getApiMaxInflightRequests() { + return 5000L; + } + + private StreamWriter(Builder builder) + throws IllegalArgumentException, IOException, InterruptedException { + Matcher matcher = streamPattern.matcher(builder.streamName); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid stream name: " + builder.streamName); + } + streamName = builder.streamName; + tableName = matcher.group(1); + + this.batchingSettings = builder.batchingSettings; + this.retrySettings = builder.retrySettings; + this.messagesBatch = new MessagesBatch(batchingSettings, this.streamName, this); + messagesBatchLock = new ReentrantLock(); + appendAndRefreshAppendLock = new ReentrantLock(); + activeAlarm = new AtomicBoolean(false); + this.exceptionLock = new ReentrantLock(); + this.streamException = null; + + executor = builder.executorProvider.getExecutor(); + backgroundResourceList = new ArrayList<>(); + if (builder.executorProvider.shouldAutoClose()) { + backgroundResourceList.add(new ExecutorAsBackgroundResource(executor)); + } + messagesWaiter = new Waiter(this.batchingSettings.getFlowControlSettings()); + responseObserver = new AppendResponseObserver(this); + + if (builder.client == null) { + stubSettings = + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(builder.credentialsProvider) + .setTransportChannelProvider(builder.channelProvider) + .setEndpoint(builder.endpoint) + .build(); + stub = BigQueryWriteClient.create(stubSettings); + backgroundResourceList.add(stub); + } else { + stub = builder.client; + } + backgroundResources = new BackgroundResourceAggregation(backgroundResourceList); + shutdown = new AtomicBoolean(false); + if (builder.onSchemaUpdateRunnable != null) { + this.onSchemaUpdateRunnable = builder.onSchemaUpdateRunnable; + this.onSchemaUpdateRunnable.setStreamWriter(this); + } + + refreshAppend(); + } + + /** Stream name we are writing to. */ + public String getStreamNameString() { + return streamName; + } + + /** Table name we are writing to. */ + public String getTableNameString() { + return tableName; + } + + /** OnSchemaUpdateRunnable for this streamWriter. */ + OnSchemaUpdateRunnable getOnSchemaUpdateRunnable() { + return this.onSchemaUpdateRunnable; + } + + private void setException(Throwable t) { + exceptionLock.lock(); + if (this.streamException == null) { + this.streamException = t; + } + exceptionLock.unlock(); + } + + /** + * Schedules the writing of a message. The write of the message may occur immediately or be + * delayed based on the writer batching options. + * + *

Example of writing a message. + * + *

{@code
+   * AppendRowsRequest message;
+   * ApiFuture<AppendRowsResponse> messageIdFuture = writer.append(message);
+   * ApiFutures.addCallback(messageIdFuture, new ApiFutureCallback<AppendRowsResponse>() {
+   *   public void onSuccess(AppendRowsResponse response) {
+   *     if (response.hasOffset()) {
+   *       System.out.println("written with offset: " + response.getOffset());
+   *     } else {
+   *       System.out.println("received an in stream error: " + response.error().toString());
+   *     }
+   *   }
+   *
+   *   public void onFailure(Throwable t) {
+   *     System.out.println("failed to write: " + t);
+   *   }
+   * }, MoreExecutors.directExecutor());
+   * }
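+   *
+   * <p>Appends are asynchronous. To block until every outstanding append has completed and to
+   * surface any stream error, a caller may additionally invoke {@link #flushAll(long)}. A sketch
+   * with an arbitrary timeout:
+   *
+   * <pre>{@code
+   * writer.flushAll(60000); // waits up to 60 seconds, then rethrows any recorded stream error
+   * }</pre>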
+ * + * @param message the message in serialized format to write to BigQuery. + * @return the message ID wrapped in a future. + */ + public ApiFuture append(AppendRowsRequest message) { + appendAndRefreshAppendLock.lock(); + Preconditions.checkState(!shutdown.get(), "Cannot append on a shut-down writer."); + Preconditions.checkNotNull(message, "Message is null."); + final AppendRequestAndFutureResponse outstandingAppend = + new AppendRequestAndFutureResponse(message); + List batchesToSend; + messagesBatchLock.lock(); + try { + batchesToSend = messagesBatch.add(outstandingAppend); + // Setup the next duration based delivery alarm if there are messages batched. + setupAlarm(); + if (!batchesToSend.isEmpty()) { + for (final InflightBatch batch : batchesToSend) { + LOG.fine("Scheduling a batch for immediate sending."); + writeBatch(batch); + } + } + } finally { + messagesBatchLock.unlock(); + appendAndRefreshAppendLock.unlock(); + } + + return outstandingAppend.appendResult; + } + + /** + * This is the general flush method for asynchronise append operation. When you have outstanding + * append requests, calling flush will make sure all outstanding append requests completed and + * successful. Otherwise there will be an exception thrown. + * + * @throws Exception + */ + public void flushAll(long timeoutMillis) throws Exception { + appendAndRefreshAppendLock.lock(); + try { + writeAllOutstanding(); + synchronized (messagesWaiter) { + messagesWaiter.waitComplete(timeoutMillis); + } + } finally { + appendAndRefreshAppendLock.unlock(); + } + exceptionLock.lock(); + try { + if (streamException != null) { + throw new Exception(streamException); + } + } finally { + exceptionLock.unlock(); + } + } + + /** + * Re-establishes a stream connection. + * + * @throws IOException + */ + public void refreshAppend() throws IOException, InterruptedException { + appendAndRefreshAppendLock.lock(); + if (shutdown.get()) { + LOG.warning("Cannot refresh on a already shutdown writer."); + appendAndRefreshAppendLock.unlock(); + return; + } + // There could be a moment, stub is not yet initialized. + if (clientStream != null) { + LOG.info("Closing the stream " + streamName); + clientStream.closeSend(); + } + messagesBatch.resetAttachSchema(); + bidiStreamingCallable = stub.appendRowsCallable(); + clientStream = bidiStreamingCallable.splitCall(responseObserver); + try { + while (!clientStream.isSendReady()) { + Thread.sleep(10); + } + } catch (InterruptedException expected) { + } + Thread.sleep(this.retrySettings.getInitialRetryDelay().toMillis()); + // Can only unlock here since need to sleep the full 7 seconds before stream can allow appends. 
+ appendAndRefreshAppendLock.unlock(); + LOG.info("Write Stream " + streamName + " connection established"); + } + + private void setupAlarm() { + if (!messagesBatch.isEmpty()) { + if (!activeAlarm.getAndSet(true)) { + long delayThresholdMs = getBatchingSettings().getDelayThreshold().toMillis(); + LOG.log(Level.FINE, "Setting up alarm for the next {0} ms.", delayThresholdMs); + currentAlarmFuture = + executor.schedule( + new Runnable() { + @Override + public void run() { + LOG.fine("Sending messages based on schedule"); + activeAlarm.getAndSet(false); + messagesBatchLock.lock(); + try { + writeBatch(messagesBatch.popBatch()); + } finally { + messagesBatchLock.unlock(); + } + } + }, + delayThresholdMs, + TimeUnit.MILLISECONDS); + } + } else if (currentAlarmFuture != null) { + LOG.log(Level.FINER, "Cancelling alarm, no more messages"); + if (activeAlarm.getAndSet(false)) { + currentAlarmFuture.cancel(false); + } + } + } + + /** + * Write any outstanding batches if non-empty. This method sends buffered messages, but does not + * wait for the send operations to complete. To wait for messages to send, call {@code get} on the + * futures returned from {@code append}. + */ + public void writeAllOutstanding() { + InflightBatch unorderedOutstandingBatch = null; + messagesBatchLock.lock(); + try { + if (!messagesBatch.isEmpty()) { + writeBatch(messagesBatch.popBatch()); + } + messagesBatch.reset(); + } finally { + messagesBatchLock.unlock(); + } + } + + private void writeBatch(final InflightBatch inflightBatch) { + if (inflightBatch != null) { + AppendRowsRequest request = inflightBatch.getMergedRequest(); + try { + messagesWaiter.acquire(inflightBatch.getByteSize()); + responseObserver.addInflightBatch(inflightBatch); + clientStream.send(request); + } catch (FlowController.FlowControlException ex) { + inflightBatch.onFailure(ex); + } + } + } + + /** Close the stream writer. Shut down all resources. */ + @Override + public void close() { + LOG.info("Closing stream writer:" + streamName); + shutdown(); + try { + awaitTermination(1, TimeUnit.MINUTES); + } catch (InterruptedException ignored) { + } + } + + // The batch of messages that is being sent/processed. + private static final class InflightBatch { + // List of requests that is going to be batched. + final List inflightRequests; + // A list tracks expected offset for each AppendRequest. Used to reconstruct the Response + // future. 
+ private final ArrayList offsetList; + private final long creationTime; + private int attempt; + private long batchSizeBytes; + private long expectedOffset; + private Boolean attachSchema; + private String streamName; + private final AtomicBoolean failed; + private final StreamWriter streamWriter; + + InflightBatch( + List inflightRequests, + long batchSizeBytes, + String streamName, + Boolean attachSchema, + StreamWriter streamWriter) { + this.inflightRequests = inflightRequests; + this.offsetList = new ArrayList(inflightRequests.size()); + for (AppendRequestAndFutureResponse request : inflightRequests) { + if (request.message.getOffset().getValue() > 0) { + offsetList.add(new Long(request.message.getOffset().getValue())); + } else { + offsetList.add(new Long(-1)); + } + } + this.expectedOffset = offsetList.get(0).longValue(); + attempt = 1; + creationTime = System.currentTimeMillis(); + this.batchSizeBytes = batchSizeBytes; + this.attachSchema = attachSchema; + this.streamName = streamName; + this.failed = new AtomicBoolean(false); + this.streamWriter = streamWriter; + } + + int count() { + return inflightRequests.size(); + } + + long getByteSize() { + return this.batchSizeBytes; + } + + long getExpectedOffset() { + return expectedOffset; + } + + private AppendRowsRequest getMergedRequest() throws IllegalStateException { + if (inflightRequests.size() == 0) { + throw new IllegalStateException("Unexpected empty message batch"); + } + ProtoRows.Builder rowsBuilder = + inflightRequests.get(0).message.getProtoRows().getRows().toBuilder(); + for (int i = 1; i < inflightRequests.size(); i++) { + rowsBuilder.addAllSerializedRows( + inflightRequests.get(i).message.getProtoRows().getRows().getSerializedRowsList()); + } + AppendRowsRequest.ProtoData.Builder data = + inflightRequests.get(0).message.getProtoRows().toBuilder().setRows(rowsBuilder.build()); + AppendRowsRequest.Builder requestBuilder = inflightRequests.get(0).message.toBuilder(); + if (!attachSchema) { + data.clearWriterSchema(); + requestBuilder.clearWriteStream(); + } else { + if (!data.hasWriterSchema()) { + throw new IllegalStateException( + "The first message on the connection must have writer schema set"); + } + requestBuilder.setWriteStream(streamName); + } + return requestBuilder.setProtoRows(data.build()).build(); + } + + private void onFailure(Throwable t) { + if (failed.getAndSet(true)) { + // Error has been set already. + LOG.warning("Ignore " + t.toString() + " since error has already been set"); + return; + } else { + LOG.info("Setting " + t.toString() + " on response"); + this.streamWriter.setException(t); + } + + for (AppendRequestAndFutureResponse request : inflightRequests) { + request.appendResult.setException(t); + } + } + + // Disassemble the batched response and sets the furture on individual request. + private void onSuccess(AppendRowsResponse response) { + for (int i = 0; i < inflightRequests.size(); i++) { + AppendRowsResponse.Builder singleResponse = response.toBuilder(); + if (offsetList.get(i) > 0) { + singleResponse.setOffset(offsetList.get(i)); + } else { + long actualOffset = response.getOffset(); + for (int j = 0; j < i; j++) { + actualOffset += + inflightRequests.get(j).message.getProtoRows().getRows().getSerializedRowsCount(); + } + singleResponse.setOffset(actualOffset); + } + inflightRequests.get(i).appendResult.set(singleResponse.build()); + } + } + } + + // Class that wraps AppendRowsRequest and its cooresponding Response future. 
+ private static final class AppendRequestAndFutureResponse { + final SettableApiFuture appendResult; + final AppendRowsRequest message; + final int messageSize; + + AppendRequestAndFutureResponse(AppendRowsRequest message) { + this.appendResult = SettableApiFuture.create(); + this.message = message; + this.messageSize = message.getProtoRows().getSerializedSize(); + if (this.messageSize > getApiMaxRequestBytes()) { + throw new StatusRuntimeException( + Status.fromCode(Status.Code.FAILED_PRECONDITION) + .withDescription("Message exceeded max size limit: " + getApiMaxRequestBytes())); + } + } + } + + /** The batching settings configured on this {@code StreamWriter}. */ + public BatchingSettings getBatchingSettings() { + return batchingSettings; + } + + /** The retry settings configured on this {@code StreamWriter}. */ + public RetrySettings getRetrySettings() { + return retrySettings; + } + + /** + * Schedules immediate flush of any outstanding messages and waits until all are processed. + * + *

Sends remaining outstanding messages and prevents future calls to publish. This method + * should be invoked prior to deleting the {@link WriteStream} object in order to ensure that no + * pending messages are lost. + */ + protected void shutdown() { + if (shutdown.getAndSet(true)) { + LOG.fine("Already shutdown."); + return; + } + LOG.fine("Shutdown called on writer"); + if (currentAlarmFuture != null && activeAlarm.getAndSet(false)) { + currentAlarmFuture.cancel(false); + } + writeAllOutstanding(); + try { + synchronized (messagesWaiter) { + messagesWaiter.waitComplete(0); + } + } catch (InterruptedException e) { + LOG.warning("Failed to wait for messages to return " + e.toString()); + } + if (clientStream.isSendReady()) { + clientStream.closeSend(); + } + backgroundResources.shutdown(); + } + + /** + * Wait for all work has completed execution after a {@link #shutdown()} request, or the timeout + * occurs, or the current thread is interrupted. + * + *

Call this method to make sure all resources are freed properly. + */ + protected boolean awaitTermination(long duration, TimeUnit unit) throws InterruptedException { + return backgroundResources.awaitTermination(duration, unit); + } + + /** + * Constructs a new {@link Builder} using the given stream. + * + *

Example of creating a write stream and a {@code StreamWriter}. + * + *

{@code
+   * String table = "projects/my_project/datasets/my_dataset/tables/my_table";
+   * String stream;
+   * try (BigQueryWriteClient bigqueryWriteClient = BigQueryWriteClient.create()) {
+   *     CreateWriteStreamRequest request = CreateWriteStreamRequest.newBuilder().setParent(table).build();
+   *     WriteStream response = bigqueryWriteClient.createWriteStream(request);
+   *     stream = response.getName();
+   * }
+   * try (StreamWriter writer = StreamWriter.newBuilder(stream).build()) {
+   *   //...
+   * }
+   * }
+ */ + public static Builder newBuilder(String streamName) { + Preconditions.checkNotNull(streamName, "StreamName is null."); + return new Builder(streamName, null); + } + + /** + * Constructs a new {@link Builder} using the given stream and an existing BigQueryWriteClient. + */ + public static Builder newBuilder(String streamName, BigQueryWriteClient client) { + Preconditions.checkNotNull(streamName, "StreamName is null."); + Preconditions.checkNotNull(client, "Client is null."); + return new Builder(streamName, client); + } + + /** A builder of {@link StreamWriter}s. */ + public static final class Builder { + static final Duration MIN_TOTAL_TIMEOUT = Duration.ofSeconds(10); + static final Duration MIN_RPC_TIMEOUT = Duration.ofMillis(10); + + // Meaningful defaults. + static final FlowControlSettings DEFAULT_FLOW_CONTROL_SETTINGS = + FlowControlSettings.newBuilder() + .setLimitExceededBehavior(FlowController.LimitExceededBehavior.Block) + .setMaxOutstandingElementCount(1000L) + .setMaxOutstandingRequestBytes(100 * 1024 * 1024L) // 100 Mb + .build(); + public static final BatchingSettings DEFAULT_BATCHING_SETTINGS = + BatchingSettings.newBuilder() + .setDelayThreshold(Duration.ofMillis(10)) + .setRequestByteThreshold(100 * 1024L) // 100 kb + .setElementCountThreshold(100L) + .setFlowControlSettings(DEFAULT_FLOW_CONTROL_SETTINGS) + .build(); + public static final RetrySettings DEFAULT_RETRY_SETTINGS = + RetrySettings.newBuilder() + .setMaxRetryDelay(Duration.ofSeconds(60)) + .setInitialRetryDelay(Duration.ofMillis(100)) + .setMaxAttempts(3) + .build(); + static final boolean DEFAULT_ENABLE_MESSAGE_ORDERING = false; + private static final int THREADS_PER_CPU = 5; + static final ExecutorProvider DEFAULT_EXECUTOR_PROVIDER = + InstantiatingExecutorProvider.newBuilder() + .setExecutorThreadCount(THREADS_PER_CPU * Runtime.getRuntime().availableProcessors()) + .build(); + + private String streamName; + private String endpoint = BigQueryWriteSettings.getDefaultEndpoint(); + + private BigQueryWriteClient client = null; + + // Batching options + BatchingSettings batchingSettings = DEFAULT_BATCHING_SETTINGS; + + RetrySettings retrySettings = DEFAULT_RETRY_SETTINGS; + + private boolean enableMessageOrdering = DEFAULT_ENABLE_MESSAGE_ORDERING; + + private TransportChannelProvider channelProvider = + BigQueryWriteSettings.defaultGrpcTransportProviderBuilder().setChannelsPerCpu(1).build(); + + ExecutorProvider executorProvider = DEFAULT_EXECUTOR_PROVIDER; + private CredentialsProvider credentialsProvider = + BigQueryWriteSettings.defaultCredentialsProviderBuilder().build(); + + private OnSchemaUpdateRunnable onSchemaUpdateRunnable; + + private Builder(String stream, BigQueryWriteClient client) { + this.streamName = Preconditions.checkNotNull(stream); + this.client = client; + } + + /** + * {@code ChannelProvider} to use to create Channels, which must point at Cloud BigQuery Storage + * API endpoint. + * + *

For performance, this client benefits from having multiple underlying connections. See + * {@link com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder#setPoolSize(int)}. + */ + public Builder setChannelProvider(TransportChannelProvider channelProvider) { + this.channelProvider = + Preconditions.checkNotNull(channelProvider, "ChannelProvider is null."); + return this; + } + + /** {@code CredentialsProvider} to use to create Credentials to authenticate calls. */ + public Builder setCredentialsProvider(CredentialsProvider credentialsProvider) { + this.credentialsProvider = + Preconditions.checkNotNull(credentialsProvider, "CredentialsProvider is null."); + return this; + } + + /** + * Sets the {@code BatchSettings} on the writer. + * + * @param batchingSettings + * @return + */ + public Builder setBatchingSettings(BatchingSettings batchingSettings) { + Preconditions.checkNotNull(batchingSettings, "BatchingSettings is null."); + + BatchingSettings.Builder builder = batchingSettings.toBuilder(); + Preconditions.checkNotNull(batchingSettings.getElementCountThreshold()); + Preconditions.checkArgument(batchingSettings.getElementCountThreshold() > 0); + Preconditions.checkNotNull(batchingSettings.getRequestByteThreshold()); + Preconditions.checkArgument(batchingSettings.getRequestByteThreshold() > 0); + if (batchingSettings.getRequestByteThreshold() > getApiMaxRequestBytes()) { + builder.setRequestByteThreshold(getApiMaxRequestBytes()); + } + Preconditions.checkNotNull(batchingSettings.getDelayThreshold()); + Preconditions.checkArgument(batchingSettings.getDelayThreshold().toMillis() > 0); + if (batchingSettings.getFlowControlSettings() == null) { + builder.setFlowControlSettings(DEFAULT_FLOW_CONTROL_SETTINGS); + } else { + + if (batchingSettings.getFlowControlSettings().getMaxOutstandingElementCount() == null) { + builder.setFlowControlSettings( + batchingSettings + .getFlowControlSettings() + .toBuilder() + .setMaxOutstandingElementCount( + DEFAULT_FLOW_CONTROL_SETTINGS.getMaxOutstandingElementCount()) + .build()); + } else { + Preconditions.checkArgument( + batchingSettings.getFlowControlSettings().getMaxOutstandingElementCount() > 0); + if (batchingSettings.getFlowControlSettings().getMaxOutstandingElementCount() + > getApiMaxInflightRequests()) { + builder.setFlowControlSettings( + batchingSettings + .getFlowControlSettings() + .toBuilder() + .setMaxOutstandingElementCount(getApiMaxInflightRequests()) + .build()); + } + } + if (batchingSettings.getFlowControlSettings().getMaxOutstandingRequestBytes() == null) { + builder.setFlowControlSettings( + batchingSettings + .getFlowControlSettings() + .toBuilder() + .setMaxOutstandingRequestBytes( + DEFAULT_FLOW_CONTROL_SETTINGS.getMaxOutstandingRequestBytes()) + .build()); + } else { + Preconditions.checkArgument( + batchingSettings.getFlowControlSettings().getMaxOutstandingRequestBytes() > 0); + } + if (batchingSettings.getFlowControlSettings().getLimitExceededBehavior() == null) { + builder.setFlowControlSettings( + batchingSettings + .getFlowControlSettings() + .toBuilder() + .setLimitExceededBehavior( + DEFAULT_FLOW_CONTROL_SETTINGS.getLimitExceededBehavior()) + .build()); + } else { + Preconditions.checkArgument( + batchingSettings.getFlowControlSettings().getLimitExceededBehavior() + != FlowController.LimitExceededBehavior.Ignore); + } + } + this.batchingSettings = builder.build(); + return this; + } + + /** + * Sets the {@code RetrySettings} on the writer. 
+ * + * @param retrySettings + * @return + */ + public Builder setRetrySettings(RetrySettings retrySettings) { + this.retrySettings = Preconditions.checkNotNull(retrySettings, "RetrySettings is null."); + return this; + } + + /** Gives the ability to set a custom executor to be used by the library. */ + public Builder setExecutorProvider(ExecutorProvider executorProvider) { + this.executorProvider = + Preconditions.checkNotNull(executorProvider, "ExecutorProvider is null."); + return this; + } + + /** Gives the ability to override the gRPC endpoint. */ + public Builder setEndpoint(String endpoint) { + this.endpoint = Preconditions.checkNotNull(endpoint, "Endpoint is null."); + return this; + } + + /** Gives the ability to set action on schema update. */ + public Builder setOnSchemaUpdateRunnable(OnSchemaUpdateRunnable onSchemaUpdateRunnable) { + this.onSchemaUpdateRunnable = + Preconditions.checkNotNull(onSchemaUpdateRunnable, "onSchemaUpdateRunnable is null."); + return this; + } + + /** Builds the {@code StreamWriter}. */ + public StreamWriter build() throws IllegalArgumentException, IOException, InterruptedException { + return new StreamWriter(this); + } + } + + private static final class AppendResponseObserver + implements ResponseObserver { + private Queue inflightBatches = new LinkedList(); + private StreamWriter streamWriter; + + public void addInflightBatch(InflightBatch batch) { + synchronized (this.inflightBatches) { + this.inflightBatches.add(batch); + } + } + + public AppendResponseObserver(StreamWriter streamWriter) { + this.streamWriter = streamWriter; + } + + private boolean isRecoverableError(Throwable t) { + Status status = Status.fromThrowable(t); + return status.getCode() == Status.Code.UNAVAILABLE; + } + + @Override + public void onStart(StreamController controller) { + // no-op + } + + private void abortInflightRequests(Throwable t) { + synchronized (this.inflightBatches) { + while (!this.inflightBatches.isEmpty()) { + InflightBatch inflightBatch = this.inflightBatches.poll(); + inflightBatch.onFailure( + new AbortedException( + "Request aborted due to previous failures", + t, + GrpcStatusCode.of(Status.Code.ABORTED), + true)); + streamWriter.messagesWaiter.release(inflightBatch.getByteSize()); + } + } + } + + @Override + public void onResponse(AppendRowsResponse response) { + InflightBatch inflightBatch = null; + synchronized (this.inflightBatches) { + inflightBatch = this.inflightBatches.poll(); + } + try { + streamWriter.currentRetries = 0; + if (response == null) { + inflightBatch.onFailure(new IllegalStateException("Response is null")); + } + if (response.hasUpdatedSchema()) { + if (streamWriter.getOnSchemaUpdateRunnable() != null) { + streamWriter.getOnSchemaUpdateRunnable().setUpdatedSchema(response.getUpdatedSchema()); + streamWriter.executor.schedule( + streamWriter.getOnSchemaUpdateRunnable(), 0L, TimeUnit.MILLISECONDS); + } + } + // Currently there is nothing retryable. If the error is already exists, then ignore it. 
+ if (response.hasError()) { + if (response.getError().getCode() != 6 /* ALREADY_EXISTS */) { + StatusRuntimeException exception = + new StatusRuntimeException( + Status.fromCodeValue(response.getError().getCode()) + .withDescription(response.getError().getMessage())); + inflightBatch.onFailure(exception); + } + } + if (inflightBatch.getExpectedOffset() > 0 + && response.getOffset() != inflightBatch.getExpectedOffset()) { + IllegalStateException exception = + new IllegalStateException( + String.format( + "The append result offset %s does not match " + "the expected offset %s.", + response.getOffset(), inflightBatch.getExpectedOffset())); + inflightBatch.onFailure(exception); + abortInflightRequests(exception); + } else { + inflightBatch.onSuccess(response); + } + } finally { + streamWriter.messagesWaiter.release(inflightBatch.getByteSize()); + } + } + + @Override + public void onComplete() { + LOG.info("OnComplete called"); + } + + @Override + public void onError(Throwable t) { + LOG.fine("OnError called"); + if (streamWriter.shutdown.get()) { + return; + } + InflightBatch inflightBatch = null; + synchronized (this.inflightBatches) { + if (inflightBatches.isEmpty()) { + // The batches could have been aborted. + return; + } + inflightBatch = this.inflightBatches.poll(); + } + try { + if (isRecoverableError(t)) { + try { + if (streamWriter.currentRetries < streamWriter.getRetrySettings().getMaxAttempts() + && !streamWriter.shutdown.get()) { + streamWriter.refreshAppend(); + LOG.info("Resending requests on transient error:" + streamWriter.currentRetries); + streamWriter.writeBatch(inflightBatch); + synchronized (streamWriter.currentRetries) { + streamWriter.currentRetries++; + } + } else { + inflightBatch.onFailure(t); + abortInflightRequests(t); + synchronized (streamWriter.currentRetries) { + streamWriter.currentRetries = 0; + } + } + } catch (IOException | InterruptedException e) { + LOG.info("Got exception while retrying."); + inflightBatch.onFailure(e); + abortInflightRequests(e); + synchronized (streamWriter.currentRetries) { + streamWriter.currentRetries = 0; + } + } + } else { + inflightBatch.onFailure(t); + abortInflightRequests(t); + synchronized (streamWriter.currentRetries) { + streamWriter.currentRetries = 0; + } + } + } finally { + streamWriter.messagesWaiter.release(inflightBatch.getByteSize()); + } + } + }; + + // This class controls how many messages are going to be sent out in a batch. + private static class MessagesBatch { + private List messages; + private long batchedBytes; + private final BatchingSettings batchingSettings; + private Boolean attachSchema = true; + private final String streamName; + private final StreamWriter streamWriter; + + private MessagesBatch( + BatchingSettings batchingSettings, String streamName, StreamWriter streamWriter) { + this.batchingSettings = batchingSettings; + this.streamName = streamName; + this.streamWriter = streamWriter; + reset(); + } + + // Get all the messages out in a batch. 
+ private InflightBatch popBatch() { + InflightBatch batch = + new InflightBatch( + messages, batchedBytes, this.streamName, this.attachSchema, this.streamWriter); + this.attachSchema = false; + reset(); + return batch; + } + + private void reset() { + messages = new LinkedList<>(); + batchedBytes = 0; + } + + private void resetAttachSchema() { + attachSchema = true; + } + + private boolean isEmpty() { + return messages.isEmpty(); + } + + private long getBatchedBytes() { + return batchedBytes; + } + + private int getMessagesCount() { + return messages.size(); + } + + private boolean hasBatchingBytes() { + return getMaxBatchBytes() > 0; + } + + private long getMaxBatchBytes() { + return batchingSettings.getRequestByteThreshold(); + } + + // The message batch returned could contain the previous batch of messages plus the current + // message. + // if the message is too large. + private List add(AppendRequestAndFutureResponse outstandingAppend) { + List batchesToSend = new ArrayList<>(); + // Check if the next message makes the current batch exceed the max batch byte size. + if (!isEmpty() + && hasBatchingBytes() + && getBatchedBytes() + outstandingAppend.messageSize >= getMaxBatchBytes()) { + batchesToSend.add(popBatch()); + } + + messages.add(outstandingAppend); + batchedBytes += outstandingAppend.messageSize; + + // Border case: If the message to send is greater or equals to the max batch size then send it + // immediately. + // Alternatively if after adding the message we have reached the batch max messages then we + // have a batch to send. + if ((hasBatchingBytes() && outstandingAppend.messageSize >= getMaxBatchBytes()) + || getMessagesCount() == batchingSettings.getElementCountThreshold()) { + batchesToSend.add(popBatch()); + } + + return batchesToSend; + } + } +} diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/Waiter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/Waiter.java new file mode 100644 index 0000000000..fd2efc489c --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/Waiter.java @@ -0,0 +1,180 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.core.InternalApi; +import com.google.api.gax.batching.FlowControlSettings; +import com.google.api.gax.batching.FlowController; +import java.util.LinkedList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.logging.Logger; + +/** + * A barrier kind of object that helps keep track of pending actions and synchronously wait until + * all have completed. 
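+ *
+ * <p>Sketch of the call pattern assumed by this change (the variable names are illustrative, not
+ * a transcript of the StreamWriter code):
+ *
+ * <pre>{@code
+ * Waiter waiter = new Waiter(flowControlSettings);
+ * waiter.acquire(batchByteSize);  // blocks or throws, depending on LimitExceededBehavior
+ * // ... the batch is sent; a response or error eventually arrives ...
+ * waiter.release(batchByteSize);  // frees the flow-control budget and wakes any waiters
+ * waiter.waitComplete(0);         // blocks until every pending acquire has been released
+ * }</pre>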
+ */ +class Waiter { + private static final Logger LOG = + Logger.getLogger(com.google.cloud.bigquery.storage.v1beta2.Waiter.class.getName()); + + private long pendingCount; + private long pendingSize; + private long countLimit; + private long sizeLimit; + private FlowController.LimitExceededBehavior behavior; + private LinkedList awaitingMessageAcquires; + private LinkedList awaitingBytesAcquires; + private final Lock lock; + + Waiter(FlowControlSettings flowControlSettings) { + pendingCount = 0; + pendingSize = 0; + this.awaitingMessageAcquires = new LinkedList(); + this.awaitingBytesAcquires = new LinkedList(); + this.countLimit = flowControlSettings.getMaxOutstandingElementCount(); + this.sizeLimit = flowControlSettings.getMaxOutstandingRequestBytes(); + this.behavior = flowControlSettings.getLimitExceededBehavior(); + this.lock = new ReentrantLock(); + } + + private void notifyNextAcquires() { + if (!awaitingMessageAcquires.isEmpty()) { + CountDownLatch awaitingAcquire = awaitingMessageAcquires.getFirst(); + awaitingAcquire.countDown(); + } + if (!awaitingBytesAcquires.isEmpty()) { + CountDownLatch awaitingAcquire = awaitingBytesAcquires.getFirst(); + awaitingAcquire.countDown(); + } + } + + public synchronized void release(long messageSize) { + lock.lock(); + --pendingCount; + pendingSize -= messageSize; + notifyNextAcquires(); + lock.unlock(); + notifyAll(); + } + + public void acquire(long messageSize) throws FlowController.FlowControlException { + lock.lock(); + try { + if (pendingCount >= countLimit + && behavior == FlowController.LimitExceededBehavior.ThrowException) { + throw new FlowController.MaxOutstandingElementCountReachedException(countLimit); + } + if (pendingSize + messageSize >= sizeLimit + && behavior == FlowController.LimitExceededBehavior.ThrowException) { + throw new FlowController.MaxOutstandingRequestBytesReachedException(sizeLimit); + } + + CountDownLatch messageWaiter = null; + while (pendingCount >= countLimit) { + if (messageWaiter == null) { + messageWaiter = new CountDownLatch(1); + awaitingMessageAcquires.addLast(messageWaiter); + } else { + // This message already in line stays at the head of the line. + messageWaiter = new CountDownLatch(1); + awaitingMessageAcquires.set(0, messageWaiter); + } + lock.unlock(); + try { + messageWaiter.await(); + } catch (InterruptedException e) { + LOG.warning("Interrupted while waiting to acquire flow control tokens"); + } + lock.lock(); + } + ++pendingCount; + if (messageWaiter != null) { + awaitingMessageAcquires.removeFirst(); + } + + if (!awaitingMessageAcquires.isEmpty() && pendingCount < countLimit) { + awaitingMessageAcquires.getFirst().countDown(); + } + + // Now acquire space for bytes. + CountDownLatch bytesWaiter = null; + Long bytesRemaining = messageSize; + while (pendingSize + messageSize >= sizeLimit) { + if (bytesWaiter == null) { + // This message gets added to the back of the line. + bytesWaiter = new CountDownLatch(1); + awaitingBytesAcquires.addLast(bytesWaiter); + } else { + // This message already in line stays at the head of the line. + bytesWaiter = new CountDownLatch(1); + awaitingBytesAcquires.set(0, bytesWaiter); + } + lock.unlock(); + try { + bytesWaiter.await(); + } catch (InterruptedException e) { + LOG.warning("Interrupted while waiting to acquire flow control tokens"); + } + lock.lock(); + } + + pendingSize += messageSize; + if (bytesWaiter != null) { + awaitingBytesAcquires.removeFirst(); + } + // There may be some surplus bytes left; let the next message waiting for bytes have some. 
+ if (!awaitingBytesAcquires.isEmpty() && pendingSize < sizeLimit) { + awaitingBytesAcquires.getFirst().countDown(); + } + } finally { + lock.unlock(); + } + } + + public synchronized void waitComplete(long timeoutMillis) throws InterruptedException { + long end = System.currentTimeMillis() + timeoutMillis; + lock.lock(); + try { + while (pendingCount > 0 && (timeoutMillis == 0 || end > System.currentTimeMillis())) { + lock.unlock(); + try { + wait(timeoutMillis == 0 ? 0 : end - System.currentTimeMillis()); + } catch (InterruptedException e) { + throw e; + } + lock.lock(); + } + if (pendingCount > 0) { + throw new InterruptedException("Wait timeout"); + } + } finally { + lock.unlock(); + } + } + + @InternalApi + public long pendingCount() { + return pendingCount; + } + + @InternalApi + public long pendingSize() { + return pendingSize; + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptorTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptorTest.java new file mode 100644 index 0000000000..6c8b945a3d --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/BQTableSchemaToProtoDescriptorTest.java @@ -0,0 +1,365 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import static org.junit.Assert.*; + +import com.google.cloud.bigquery.storage.test.JsonTest.*; +import com.google.cloud.bigquery.storage.test.SchemaTest.*; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BQTableSchemaToProtoDescriptorTest { + // This is a map between the TableFieldSchema.Type and the descriptor it is supposed to + // produce. The produced descriptor will be used to check against the entry values here. 
+ private static ImmutableMap + BQTableTypeToCorrectProtoDescriptorTest = + new ImmutableMap.Builder() + .put(TableFieldSchema.Type.BOOL, BoolType.getDescriptor()) + .put(TableFieldSchema.Type.BYTES, BytesType.getDescriptor()) + .put(TableFieldSchema.Type.DATE, Int32Type.getDescriptor()) + .put(TableFieldSchema.Type.DATETIME, Int64Type.getDescriptor()) + .put(TableFieldSchema.Type.DOUBLE, DoubleType.getDescriptor()) + .put(TableFieldSchema.Type.GEOGRAPHY, StringType.getDescriptor()) + .put(TableFieldSchema.Type.INT64, Int64Type.getDescriptor()) + .put(TableFieldSchema.Type.NUMERIC, BytesType.getDescriptor()) + .put(TableFieldSchema.Type.STRING, StringType.getDescriptor()) + .put(TableFieldSchema.Type.TIME, Int64Type.getDescriptor()) + .put(TableFieldSchema.Type.TIMESTAMP, Int64Type.getDescriptor()) + .build(); + + // Creates mapping from descriptor to how many times it was reused. + private void mapDescriptorToCount(Descriptor descriptor, HashMap map) { + for (FieldDescriptor field : descriptor.getFields()) { + if (field.getType() == FieldDescriptor.Type.MESSAGE) { + Descriptor subDescriptor = field.getMessageType(); + String messageName = subDescriptor.getName(); + if (map.containsKey(messageName)) { + map.put(messageName, map.get(messageName) + 1); + } else { + map.put(messageName, 1); + } + mapDescriptorToCount(subDescriptor, map); + } + } + } + + private void isDescriptorEqual(Descriptor convertedProto, Descriptor originalProto) { + // Check same number of fields + assertEquals(convertedProto.getFields().size(), originalProto.getFields().size()); + for (FieldDescriptor convertedField : convertedProto.getFields()) { + // Check field name + FieldDescriptor originalField = originalProto.findFieldByName(convertedField.getName()); + assertNotNull(originalField); + // Check type + FieldDescriptor.Type convertedType = convertedField.getType(); + FieldDescriptor.Type originalType = originalField.getType(); + assertEquals(convertedType, originalType); + // Check mode + assertTrue( + (originalField.isRepeated() == convertedField.isRepeated()) + && (originalField.isRequired() == convertedField.isRequired()) + && (originalField.isOptional() == convertedField.isOptional())); + // Recursively check nested messages + if (convertedType == FieldDescriptor.Type.MESSAGE) { + isDescriptorEqual(convertedField.getMessageType(), originalField.getMessageType()); + } + } + } + + @Test + public void testSimpleTypes() throws Exception { + for (Map.Entry entry : + BQTableTypeToCorrectProtoDescriptorTest.entrySet()) { + final TableFieldSchema tableFieldSchema = + TableFieldSchema.newBuilder() + .setType(entry.getKey()) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .build(); + final TableSchema tableSchema = + TableSchema.newBuilder().addFields(0, tableFieldSchema).build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, entry.getValue()); + } + } + + @Test + public void testStructSimple() throws Exception { + final TableFieldSchema StringType = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .build(); + final TableFieldSchema tableFieldSchema = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .addFields(0, StringType) + .build(); + final TableSchema tableSchema = 
TableSchema.newBuilder().addFields(0, tableFieldSchema).build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, MessageType.getDescriptor()); + } + + @Test + public void testStructComplex() throws Exception { + final TableFieldSchema test_int = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_int") + .build(); + final TableFieldSchema test_string = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("test_string") + .build(); + final TableFieldSchema test_bytes = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BYTES) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("test_bytes") + .build(); + final TableFieldSchema test_bool = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BOOL) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_bool") + .build(); + final TableFieldSchema test_double = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DOUBLE) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("test_double") + .build(); + final TableFieldSchema test_date = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DATE) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("test_date") + .build(); + final TableFieldSchema ComplexLvl2 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.REQUIRED) + .addFields(0, test_int) + .setName("complex_lvl2") + .build(); + final TableFieldSchema ComplexLvl1 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.REQUIRED) + .addFields(0, test_int) + .addFields(1, ComplexLvl2) + .setName("complex_lvl1") + .build(); + final TableSchema tableSchema = + TableSchema.newBuilder() + .addFields(0, test_int) + .addFields(1, test_string) + .addFields(2, test_bytes) + .addFields(3, test_bool) + .addFields(4, test_double) + .addFields(5, test_date) + .addFields(6, ComplexLvl1) + .addFields(7, ComplexLvl2) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, ComplexRoot.getDescriptor()); + } + + @Test + public void testCasingComplexStruct() throws Exception { + final TableFieldSchema required = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("tEsT_ReQuIrEd") + .build(); + final TableFieldSchema repeated = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("tESt_repEATed") + .build(); + final TableFieldSchema optional = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_opTIONal") + .build(); + final TableFieldSchema test_int = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("TEST_INT") + .build(); + final TableFieldSchema test_string = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("TEST_STRING") + .build(); + final TableFieldSchema test_bytes = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BYTES) + 
.setMode(TableFieldSchema.Mode.REQUIRED) + .setName("TEST_BYTES") + .build(); + final TableFieldSchema test_bool = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BOOL) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("TEST_BOOL") + .build(); + final TableFieldSchema test_double = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DOUBLE) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("TEST_DOUBLE") + .build(); + final TableFieldSchema test_date = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DATE) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("TEST_DATE") + .build(); + final TableFieldSchema option_test = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.REQUIRED) + .addFields(0, required) + .addFields(1, repeated) + .addFields(2, optional) + .setName("option_test") + .build(); + final TableSchema tableSchema = + TableSchema.newBuilder() + .addFields(0, test_int) + .addFields(1, test_string) + .addFields(2, test_bytes) + .addFields(3, test_bool) + .addFields(4, test_double) + .addFields(5, test_date) + .addFields(6, option_test) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, CasingComplex.getDescriptor()); + } + + @Test + public void testOptions() throws Exception { + final TableFieldSchema required = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("test_required") + .build(); + final TableFieldSchema repeated = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("test_repeated") + .build(); + final TableFieldSchema optional = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_optional") + .build(); + final TableSchema tableSchema = + TableSchema.newBuilder() + .addFields(0, required) + .addFields(1, repeated) + .addFields(2, optional) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, OptionTest.getDescriptor()); + } + + @Test + public void testDescriptorReuseDuringCreation() throws Exception { + final TableFieldSchema test_int = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_int") + .build(); + final TableFieldSchema reuse_lvl2 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl2") + .addFields(0, test_int) + .build(); + final TableFieldSchema reuse_lvl1 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + .build(); + final TableFieldSchema reuse_lvl1_1 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1_1") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + .build(); + final TableFieldSchema reuse_lvl1_2 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1_2") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + 
.build(); + final TableSchema tableSchema = + TableSchema.newBuilder() + .addFields(0, reuse_lvl1) + .addFields(1, reuse_lvl1_1) + .addFields(2, reuse_lvl1_2) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.convertBQTableSchemaToProtoDescriptor(tableSchema); + HashMap descriptorToCount = new HashMap(); + mapDescriptorToCount(descriptor, descriptorToCount); + assertEquals(descriptorToCount.size(), 2); + assertTrue(descriptorToCount.containsKey("root__reuse_lvl1")); + assertEquals(descriptorToCount.get("root__reuse_lvl1").intValue(), 3); + assertTrue(descriptorToCount.containsKey("root__reuse_lvl1__reuse_lvl2")); + assertEquals(descriptorToCount.get("root__reuse_lvl1__reuse_lvl2").intValue(), 3); + isDescriptorEqual(descriptor, ReuseRoot.getDescriptor()); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWrite.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWrite.java new file mode 100644 index 0000000000..618366cfdc --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWrite.java @@ -0,0 +1,85 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.gax.grpc.testing.MockGrpcService; +import com.google.protobuf.AbstractMessage; +import io.grpc.ServerServiceDefinition; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import org.threeten.bp.Duration; + +/** + * A fake implementation of {@link MockGrpcService}, that can be used to test clients of a + * StreamWriter. It forwards calls to the real implementation (@link FakeBigQueryWriteImpl}. 
+ */ +public class FakeBigQueryWrite implements MockGrpcService { + private final FakeBigQueryWriteImpl serviceImpl; + + public FakeBigQueryWrite() { + serviceImpl = new FakeBigQueryWriteImpl(); + } + + @Override + public List getRequests() { + return new LinkedList(serviceImpl.getCapturedRequests()); + } + + public List getAppendRequests() { + return serviceImpl.getCapturedRequests(); + } + + public List getWriteStreamRequests() { + return serviceImpl.getCapturedWriteRequests(); + } + + @Override + public void addResponse(AbstractMessage response) { + if (response instanceof AppendRowsResponse) { + serviceImpl.addResponse((AppendRowsResponse) response); + } else if (response instanceof WriteStream) { + serviceImpl.addWriteStreamResponse((WriteStream) response); + } else if (response instanceof FlushRowsResponse) { + serviceImpl.addFlushRowsResponse((FlushRowsResponse) response); + } else { + throw new IllegalStateException("Unsupported service"); + } + } + + @Override + public void addException(Exception exception) { + serviceImpl.addConnectionError(exception); + } + + @Override + public ServerServiceDefinition getServiceDefinition() { + return serviceImpl.bindService(); + } + + @Override + public void reset() { + serviceImpl.reset(); + } + + public void setResponseDelay(Duration delay) { + serviceImpl.setResponseDelay(delay); + } + + public void setExecutor(ScheduledExecutorService executor) { + serviceImpl.setExecutor(executor); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWriteImpl.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWriteImpl.java new file mode 100644 index 0000000000..7cef4f7483 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeBigQueryWriteImpl.java @@ -0,0 +1,212 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.common.base.Optional; +import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * A fake implementation of {@link BigQueryWriteImplBase} that can acts like server in StreamWriter + * unit testing. 
+ */ +class FakeBigQueryWriteImpl extends BigQueryWriteGrpc.BigQueryWriteImplBase { + private static final Logger LOG = Logger.getLogger(FakeBigQueryWriteImpl.class.getName()); + + private final LinkedBlockingQueue requests = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue writeRequests = + new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue flushRequests = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue responses = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue writeResponses = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue flushResponses = new LinkedBlockingQueue<>(); + private final AtomicInteger nextMessageId = new AtomicInteger(1); + private boolean autoPublishResponse; + private ScheduledExecutorService executor = null; + private Duration responseDelay = Duration.ZERO; + + /** Class used to save the state of a possible response. */ + private static class Response { + Optional appendResponse; + Optional error; + + public Response(AppendRowsResponse appendResponse) { + this.appendResponse = Optional.of(appendResponse); + this.error = Optional.absent(); + } + + public Response(Throwable exception) { + this.appendResponse = Optional.absent(); + this.error = Optional.of(exception); + } + + public AppendRowsResponse getResponse() { + return appendResponse.get(); + } + + public Throwable getError() { + return error.get(); + } + + boolean isError() { + return error.isPresent(); + } + + @Override + public String toString() { + if (isError()) { + return error.get().toString(); + } + return appendResponse.get().toString(); + } + } + + @Override + public void getWriteStream( + GetWriteStreamRequest request, StreamObserver responseObserver) { + Object response = writeResponses.remove(); + if (response instanceof WriteStream) { + writeRequests.add(request); + responseObserver.onNext((WriteStream) response); + responseObserver.onCompleted(); + } else if (response instanceof Exception) { + responseObserver.onError((Exception) response); + } else { + responseObserver.onError(new IllegalArgumentException("Unrecognized response type")); + } + } + + @Override + public void flushRows( + FlushRowsRequest request, StreamObserver responseObserver) { + Object response = writeResponses.remove(); + if (response instanceof FlushRowsResponse) { + flushRequests.add(request); + responseObserver.onNext((FlushRowsResponse) response); + responseObserver.onCompleted(); + } else if (response instanceof Exception) { + responseObserver.onError((Exception) response); + } else { + responseObserver.onError(new IllegalArgumentException("Unrecognized response type")); + } + } + + @Override + public StreamObserver appendRows( + final StreamObserver responseObserver) { + StreamObserver requestObserver = + new StreamObserver() { + @Override + public void onNext(AppendRowsRequest value) { + LOG.info("Get request:" + value.toString()); + final Response response = responses.remove(); + requests.add(value); + if (responseDelay == Duration.ZERO) { + sendResponse(response, responseObserver); + } else { + final Response responseToSend = response; + LOG.info("Schedule a response to be sent at delay"); + executor.schedule( + new Runnable() { + @Override + public void run() { + sendResponse(responseToSend, responseObserver); + } + }, + responseDelay.toMillis(), + TimeUnit.MILLISECONDS); + } + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + 
} + }; + return requestObserver; + } + + private void sendResponse( + Response response, StreamObserver responseObserver) { + LOG.info("Sending response: " + response.toString()); + if (response.isError()) { + responseObserver.onError(response.getError()); + } else { + responseObserver.onNext(response.getResponse()); + } + } + + /** Set an executor to use to delay publish responses. */ + public FakeBigQueryWriteImpl setExecutor(ScheduledExecutorService executor) { + this.executor = executor; + return this; + } + + /** Set an amount of time by which to delay publish responses. */ + public FakeBigQueryWriteImpl setResponseDelay(Duration responseDelay) { + this.responseDelay = responseDelay; + return this; + } + + public FakeBigQueryWriteImpl addResponse(AppendRowsResponse appendRowsResponse) { + responses.add(new Response(appendRowsResponse)); + return this; + } + + public FakeBigQueryWriteImpl addResponse(AppendRowsResponse.Builder appendResponseBuilder) { + return addResponse(appendResponseBuilder.build()); + } + + public FakeBigQueryWriteImpl addWriteStreamResponse(WriteStream response) { + writeResponses.add(response); + return this; + } + + public FakeBigQueryWriteImpl addFlushRowsResponse(FlushRowsResponse response) { + flushResponses.add(response); + return this; + } + + public FakeBigQueryWriteImpl addConnectionError(Throwable error) { + responses.add(new Response(error)); + return this; + } + + public List getCapturedRequests() { + return new ArrayList(requests); + } + + public List getCapturedWriteRequests() { + return new ArrayList(writeRequests); + } + + public void reset() { + requests.clear(); + responses.clear(); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeClock.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeClock.java new file mode 100644 index 0000000000..c5b8610d6e --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeClock.java @@ -0,0 +1,41 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.core.ApiClock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +/** A Clock to help with testing time-based logic. */ +public class FakeClock implements ApiClock { + + private final AtomicLong millis = new AtomicLong(); + + // Advances the clock value by {@code time} in {@code timeUnit}. 
+ public void advance(long time, TimeUnit timeUnit) { + millis.addAndGet(timeUnit.toMillis(time)); + } + + @Override + public long nanoTime() { + return millisTime() * 1000_000L; + } + + @Override + public long millisTime() { + return millis.get(); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeScheduledExecutorService.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeScheduledExecutorService.java new file mode 100644 index 0000000000..11a8311014 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/FakeScheduledExecutorService.java @@ -0,0 +1,346 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.core.ApiClock; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.SettableFuture; +import java.util.ArrayList; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.PriorityQueue; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.Callable; +import java.util.concurrent.Delayed; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Logger; +import org.threeten.bp.Duration; +import org.threeten.bp.Instant; + +/** + * Fake implementation of {@link ScheduledExecutorService} that allows tests to control the reference + * time of the executor and decide when to execute any outstanding task.
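+ *
+ * <p>A rough sketch of the intended use (illustrative only; {@code Duration} here is
+ * {@code org.threeten.bp.Duration}):
+ *
+ * <pre>{@code
+ * FakeScheduledExecutorService executor = new FakeScheduledExecutorService();
+ * executor.schedule(task, 10, TimeUnit.SECONDS);   // task is any Runnable
+ * executor.advanceTime(Duration.ofSeconds(10));    // runs the due task on the calling thread
+ * }</pre>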
+ */ +public class FakeScheduledExecutorService extends AbstractExecutorService + implements ScheduledExecutorService { + private static final Logger LOG = Logger.getLogger(FakeScheduledExecutorService.class.getName()); + + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final PriorityQueue> pendingCallables = new PriorityQueue<>(); + private final FakeClock clock = new FakeClock(); + private final Deque expectedWorkQueue = new LinkedList<>(); + + public ApiClock getClock() { + return clock; + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return schedulePendingCallable( + new PendingCallable<>( + Duration.ofMillis(unit.toMillis(delay)), command, PendingCallableType.NORMAL)); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return schedulePendingCallable( + new PendingCallable<>( + Duration.ofMillis(unit.toMillis(delay)), callable, PendingCallableType.NORMAL)); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { + return schedulePendingCallable( + new PendingCallable<>( + Duration.ofMillis(unit.toMillis(initialDelay)), + command, + PendingCallableType.FIXED_RATE)); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { + return schedulePendingCallable( + new PendingCallable<>( + Duration.ofMillis(unit.toMillis(initialDelay)), + command, + PendingCallableType.FIXED_DELAY)); + } + + /** + * This will advance the reference time of the executor and execute (in the same thread) any + * outstanding callable which execution time has passed. + */ + public void advanceTime(Duration toAdvance) { + LOG.info( + "Advance to time to:" + + Instant.ofEpochMilli(clock.millisTime() + toAdvance.toMillis()).toString()); + clock.advance(toAdvance.toMillis(), TimeUnit.MILLISECONDS); + work(); + } + + private void work() { + for (; ; ) { + PendingCallable callable = null; + Instant cmpTime = Instant.ofEpochMilli(clock.millisTime()); + if (!pendingCallables.isEmpty()) { + LOG.info( + "Going to call: Current time: " + + cmpTime.toString() + + " Scheduled time: " + + pendingCallables.peek().getScheduledTime().toString() + + " Creation time:" + + pendingCallables.peek().getCreationTime().toString()); + } + synchronized (pendingCallables) { + if (pendingCallables.isEmpty() + || pendingCallables.peek().getScheduledTime().isAfter(cmpTime)) { + break; + } + callable = pendingCallables.poll(); + } + if (callable != null) { + try { + callable.call(); + } catch (Exception e) { + // We ignore any callable exception, which should be set to the future but not relevant to + // advanceTime. 
+ } + } + } + + synchronized (pendingCallables) { + if (shutdown.get() && pendingCallables.isEmpty()) { + pendingCallables.notifyAll(); + } + } + } + + @Override + public void shutdown() { + if (shutdown.getAndSet(true)) { + throw new IllegalStateException("This executor has been shutdown already"); + } + } + + @Override + public List shutdownNow() { + if (shutdown.getAndSet(true)) { + throw new IllegalStateException("This executor has been shutdown already"); + } + List pending = new ArrayList<>(); + for (final PendingCallable pendingCallable : pendingCallables) { + pending.add( + new Runnable() { + @Override + public void run() { + pendingCallable.call(); + } + }); + } + synchronized (pendingCallables) { + pendingCallables.notifyAll(); + pendingCallables.clear(); + } + return pending; + } + + @Override + public boolean isShutdown() { + return shutdown.get(); + } + + @Override + public boolean isTerminated() { + return pendingCallables.isEmpty(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + synchronized (pendingCallables) { + if (pendingCallables.isEmpty()) { + return true; + } + LOG.info("Waiting on pending callables: " + pendingCallables.size()); + pendingCallables.wait(unit.toMillis(timeout)); + return pendingCallables.isEmpty(); + } + } + + @Override + public void execute(Runnable command) { + if (shutdown.get()) { + throw new IllegalStateException("This executor has been shutdown"); + } + command.run(); + } + + ScheduledFuture schedulePendingCallable(PendingCallable callable) { + LOG.info( + "Schedule pending callable called " + callable.delay + " " + callable.getScheduledTime()); + if (shutdown.get()) { + throw new IllegalStateException("This executor has been shutdown"); + } + synchronized (pendingCallables) { + pendingCallables.add(callable); + } + work(); + synchronized (expectedWorkQueue) { + // We compare by the callable delay in order to decide when to remove expectations from the + // expected work queue, i.e. only the expected work that matches the delay of the scheduled + // callable is removed from the queue. + if (!expectedWorkQueue.isEmpty() && expectedWorkQueue.peek().equals(callable.delay)) { + expectedWorkQueue.poll(); + } + expectedWorkQueue.notifyAll(); + } + + return callable.getScheduledFuture(); + } + + enum PendingCallableType { + NORMAL, + FIXED_RATE, + FIXED_DELAY + } + + /** Class that saves the state of a scheduled pending callable. 
*/ + class PendingCallable implements Comparable> { + Instant creationTime = Instant.ofEpochMilli(clock.millisTime()); + Duration delay; + Callable pendingCallable; + SettableFuture future = SettableFuture.create(); + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); + PendingCallableType type; + + PendingCallable(Duration delay, final Runnable runnable, PendingCallableType type) { + pendingCallable = + new Callable() { + @Override + public T call() { + runnable.run(); + return null; + } + }; + this.type = type; + this.delay = delay; + } + + PendingCallable(Duration delay, Callable callable, PendingCallableType type) { + pendingCallable = callable; + this.type = type; + this.delay = delay; + } + + private Instant getScheduledTime() { + return creationTime.plus(delay); + } + + private Instant getCreationTime() { + return creationTime; + } + + ScheduledFuture getScheduledFuture() { + return new ScheduledFuture() { + @Override + public long getDelay(TimeUnit unit) { + return unit.convert( + getScheduledTime().toEpochMilli() - clock.millisTime(), TimeUnit.MILLISECONDS); + } + + @Override + public int compareTo(Delayed o) { + return Ints.saturatedCast( + getDelay(TimeUnit.MILLISECONDS) - o.getDelay(TimeUnit.MILLISECONDS)); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + synchronized (this) { + cancelled.set(true); + return !done.get(); + } + } + + @Override + public boolean isCancelled() { + return cancelled.get(); + } + + @Override + public boolean isDone() { + return done.get(); + } + + @Override + public T get() throws InterruptedException, ExecutionException { + return future.get(); + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return future.get(timeout, unit); + } + }; + } + + T call() { + T result = null; + synchronized (this) { + if (cancelled.get()) { + return null; + } + try { + result = pendingCallable.call(); + future.set(result); + } catch (Exception e) { + future.setException(e); + } finally { + switch (type) { + case NORMAL: + done.set(true); + break; + case FIXED_DELAY: + this.creationTime = Instant.ofEpochMilli(clock.millisTime()); + schedulePendingCallable(this); + break; + case FIXED_RATE: + this.creationTime = this.creationTime.plus(delay); + schedulePendingCallable(this); + break; + default: + // Nothing to do + } + } + } + return result; + } + + @Override + public int compareTo(PendingCallable other) { + return getScheduledTime().compareTo(other.getScheduledTime()); + } + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriterTest.java new file mode 100644 index 0000000000..4fc3e13ef5 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonStreamWriterTest.java @@ -0,0 +1,958 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.api.core.ApiFuture; +import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.core.InstantiatingExecutorProvider; +import com.google.api.gax.core.NoCredentialsProvider; +import com.google.api.gax.grpc.testing.LocalChannelProvider; +import com.google.api.gax.grpc.testing.MockGrpcService; +import com.google.api.gax.grpc.testing.MockServiceHelper; +import com.google.cloud.bigquery.storage.test.JsonTest.ComplexRoot; +import com.google.cloud.bigquery.storage.test.Test.FooType; +import com.google.cloud.bigquery.storage.test.Test.UpdatedFooType; +import com.google.cloud.bigquery.storage.test.Test.UpdatedFooType2; +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import com.google.protobuf.Timestamp; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.logging.Logger; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.threeten.bp.Instant; + +@RunWith(JUnit4.class) +public class JsonStreamWriterTest { + private static final Logger LOG = Logger.getLogger(JsonStreamWriterTest.class.getName()); + private static final String TEST_STREAM = "projects/p/datasets/d/tables/t/streams/s"; + private static final ExecutorProvider SINGLE_THREAD_EXECUTOR = + InstantiatingExecutorProvider.newBuilder().setExecutorThreadCount(1).build(); + private static LocalChannelProvider channelProvider; + private FakeScheduledExecutorService fakeExecutor; + private FakeBigQueryWrite testBigQueryWrite; + private static MockServiceHelper serviceHelper; + + private final TableFieldSchema FOO = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("foo") + .build(); + private final TableSchema TABLE_SCHEMA = TableSchema.newBuilder().addFields(0, FOO).build(); + + private final TableFieldSchema BAR = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("bar") + .build(); + private final TableFieldSchema BAZ = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("baz") + .build(); + private final TableSchema UPDATED_TABLE_SCHEMA = + TableSchema.newBuilder().addFields(0, FOO).addFields(1, BAR).build(); + private final TableSchema UPDATED_TABLE_SCHEMA_2 = + TableSchema.newBuilder().addFields(0, FOO).addFields(1, BAR).addFields(2, BAZ).build(); + + private final TableFieldSchema TEST_INT = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.INT64) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_int") + .build(); + private final TableFieldSchema TEST_STRING = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRING) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("test_string") + .build(); + private final TableFieldSchema TEST_BYTES = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BYTES) + 
.setMode(TableFieldSchema.Mode.REQUIRED) + .setName("test_bytes") + .build(); + private final TableFieldSchema TEST_BOOL = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.BOOL) + .setMode(TableFieldSchema.Mode.NULLABLE) + .setName("test_bool") + .build(); + private final TableFieldSchema TEST_DOUBLE = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DOUBLE) + .setMode(TableFieldSchema.Mode.REPEATED) + .setName("test_double") + .build(); + private final TableFieldSchema TEST_DATE = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.DATE) + .setMode(TableFieldSchema.Mode.REQUIRED) + .setName("test_date") + .build(); + private final TableFieldSchema COMPLEXLVL2 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.REQUIRED) + .addFields(0, TEST_INT) + .setName("complex_lvl2") + .build(); + private final TableFieldSchema COMPLEXLVL1 = + TableFieldSchema.newBuilder() + .setType(TableFieldSchema.Type.STRUCT) + .setMode(TableFieldSchema.Mode.REQUIRED) + .addFields(0, TEST_INT) + .addFields(1, COMPLEXLVL2) + .setName("complex_lvl1") + .build(); + private final TableSchema COMPLEX_TABLE_SCHEMA = + TableSchema.newBuilder() + .addFields(0, TEST_INT) + .addFields(1, TEST_STRING) + .addFields(2, TEST_BYTES) + .addFields(3, TEST_BOOL) + .addFields(4, TEST_DOUBLE) + .addFields(5, TEST_DATE) + .addFields(6, COMPLEXLVL1) + .addFields(7, COMPLEXLVL2) + .build(); + + @Before + public void setUp() throws Exception { + testBigQueryWrite = new FakeBigQueryWrite(); + serviceHelper = + new MockServiceHelper( + UUID.randomUUID().toString(), Arrays.asList(testBigQueryWrite)); + serviceHelper.start(); + channelProvider = serviceHelper.createChannelProvider(); + fakeExecutor = new FakeScheduledExecutorService(); + testBigQueryWrite.setExecutor(fakeExecutor); + Instant time = Instant.now(); + Timestamp timestamp = + Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build(); + // Add enough GetWriteStream response. 
+ for (int i = 0; i < 4; i++) { + testBigQueryWrite.addResponse( + WriteStream.newBuilder().setName(TEST_STREAM).setCreateTime(timestamp).build()); + } + } + + @After + public void tearDown() throws Exception { + serviceHelper.stop(); + } + + private JsonStreamWriter.Builder getTestJsonStreamWriterBuilder( + String testStream, TableSchema BQTableSchema) { + return JsonStreamWriter.newBuilder(testStream, BQTableSchema) + .setChannelProvider(channelProvider) + .setExecutorProvider(SINGLE_THREAD_EXECUTOR) + .setCredentialsProvider(NoCredentialsProvider.create()); + } + + @Test + public void testTwoParamNewBuilder_nullSchema() { + try { + getTestJsonStreamWriterBuilder(null, TABLE_SCHEMA); + Assert.fail("expected NullPointerException"); + } catch (NullPointerException e) { + assertEquals(e.getMessage(), "StreamName is null."); + } + } + + @Test + public void testTwoParamNewBuilder_nullStream() { + try { + getTestJsonStreamWriterBuilder(TEST_STREAM, null); + Assert.fail("expected NullPointerException"); + } catch (NullPointerException e) { + assertEquals(e.getMessage(), "TableSchema is null."); + } + } + + @Test + public void testTwoParamNewBuilder() + throws DescriptorValidationException, IOException, InterruptedException { + JsonStreamWriter writer = getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build(); + assertEquals(TEST_STREAM, writer.getStreamName()); + } + + @Test + public void testSingleAppendSimpleJson() throws Exception { + FooType expectedProto = FooType.newBuilder().setFoo("allen").build(); + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(0L, appendFuture.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(0), + expectedProto.toByteString()); + } + } + + @Test + public void testSingleAppendMultipleSimpleJson() throws Exception { + FooType expectedProto = FooType.newBuilder().setFoo("allen").build(); + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONObject foo1 = new JSONObject(); + foo1.put("foo", "allen"); + JSONObject foo2 = new JSONObject(); + foo2.put("foo", "allen"); + JSONObject foo3 = new JSONObject(); + foo3.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + jsonArr.put(foo1); + jsonArr.put(foo2); + jsonArr.put(foo3); + + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(0L, appendFuture.get().getOffset()); + assertEquals( + 4, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + for (int i = 0; i < 4; i++) { + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(i), + expectedProto.toByteString()); + } + } + } + + @Test + public void 
testMultipleAppendSimpleJson() throws Exception { + FooType expectedProto = FooType.newBuilder().setFoo("allen").build(); + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(1).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(3).build()); + ApiFuture appendFuture; + for (int i = 0; i < 4; i++) { + appendFuture = writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + assertEquals((long) i, appendFuture.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRows(0), + expectedProto.toByteString()); + } + } + } + + @Test + public void testSingleAppendComplexJson() throws Exception { + ComplexRoot expectedProto = + ComplexRoot.newBuilder() + .setTestInt(1) + .addTestString("a") + .addTestString("b") + .addTestString("c") + .setTestBytes(ByteString.copyFrom("hello".getBytes())) + .setTestBool(true) + .addTestDouble(1.1) + .addTestDouble(2.2) + .addTestDouble(3.3) + .addTestDouble(4.4) + .setTestDate(1) + .setComplexLvl1( + com.google.cloud.bigquery.storage.test.JsonTest.ComplexLvl1.newBuilder() + .setTestInt(2) + .setComplexLvl2( + com.google.cloud.bigquery.storage.test.JsonTest.ComplexLvl2.newBuilder() + .setTestInt(3) + .build()) + .build()) + .setComplexLvl2( + com.google.cloud.bigquery.storage.test.JsonTest.ComplexLvl2.newBuilder() + .setTestInt(3) + .build()) + .build(); + JSONObject complex_lvl2 = new JSONObject(); + complex_lvl2.put("test_int", 3); + + JSONObject complex_lvl1 = new JSONObject(); + complex_lvl1.put("test_int", 2); + complex_lvl1.put("complex_lvl2", complex_lvl2); + + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("test_string", new JSONArray(new String[] {"a", "b", "c"})); + json.put("test_bytes", "hello"); + json.put("test_bool", true); + json.put("test_DOUBLe", new JSONArray(new Double[] {1.1, 2.2, 3.3, 4.4})); + json.put("test_date", 1); + json.put("complex_lvl1", complex_lvl1); + json.put("complex_lvl2", complex_lvl2); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(json); + + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, COMPLEX_TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(0L, appendFuture.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(0), + expectedProto.toByteString()); + } + } + + @Test + public void testAppendMultipleSchemaUpdate() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + // Add fake resposne for FakeBigQueryWrite, first response has updated schema. 
+ testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setOffset(0) + .setUpdatedSchema(UPDATED_TABLE_SCHEMA) + .build()); + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setOffset(1) + .setUpdatedSchema(UPDATED_TABLE_SCHEMA_2) + .build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + // First append + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + ApiFuture appendFuture1 = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + int millis = 0; + while (millis <= 10000) { + if (writer.getDescriptor().getFields().size() == 2) { + break; + } + Thread.sleep(100); + millis += 100; + } + assertTrue(writer.getDescriptor().getFields().size() == 2); + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(0), + FooType.newBuilder().setFoo("allen").build().toByteString()); + + // Second append with updated schema. + JSONObject updatedFoo = new JSONObject(); + updatedFoo.put("foo", "allen"); + updatedFoo.put("bar", "allen2"); + JSONArray updatedJsonArr = new JSONArray(); + updatedJsonArr.put(updatedFoo); + + ApiFuture appendFuture2 = + writer.append(updatedJsonArr, -1, /* allowUnknownFields */ false); + + millis = 0; + while (millis <= 10000) { + if (writer.getDescriptor().getFields().size() == 3) { + break; + } + Thread.sleep(100); + millis += 100; + } + assertTrue(writer.getDescriptor().getFields().size() == 3); + assertEquals(1L, appendFuture2.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRows(0), + UpdatedFooType.newBuilder().setFoo("allen").setBar("allen2").build().toByteString()); + + // Third append with updated schema. + JSONObject updatedFoo2 = new JSONObject(); + updatedFoo2.put("foo", "allen"); + updatedFoo2.put("bar", "allen2"); + updatedFoo2.put("baz", "allen3"); + JSONArray updatedJsonArr2 = new JSONArray(); + updatedJsonArr2.put(updatedFoo2); + + ApiFuture appendFuture3 = + writer.append(updatedJsonArr2, -1, /* allowUnknownFields */ false); + + assertEquals(2L, appendFuture3.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(2) + .getProtoRows() + .getRows() + .getSerializedRows(0), + UpdatedFooType2.newBuilder() + .setFoo("allen") + .setBar("allen2") + .setBaz("allen3") + .build() + .toByteString()); + // // Check if writer schemas were added in for both connections. + assertTrue(testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + assertTrue(testBigQueryWrite.getAppendRequests().get(1).getProtoRows().hasWriterSchema()); + assertTrue(testBigQueryWrite.getAppendRequests().get(2).getProtoRows().hasWriterSchema()); + } + } + + @Test + // This might be a bug but it is the current behavior. Investigate. 
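+ // (Status code 6 is ALREADY_EXISTS; as exercised below, the append future currently completes
+ // normally instead of failing, which is why the test only verifies that get() does not throw.)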
+ public void testAppendAlreadyExists_doesNotThrowxception() + throws DescriptorValidationException, IOException, InterruptedException, ExecutionException { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setError(com.google.rpc.Status.newBuilder().setCode(6).build()) + .build()); + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + appendFuture.get(); + } + } + + @Test + public void testAppendOutOfRangeException() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setError(com.google.rpc.Status.newBuilder().setCode(11).build()) + .build()); + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + try { + appendFuture.get(); + Assert.fail("expected ExecutionException"); + } catch (ExecutionException ex) { + assertEquals(ex.getCause().getMessage(), "OUT_OF_RANGE: "); + } + } + } + + @Test + public void testAppendOutOfRangeAndUpdateSchema() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA).build()) { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setError(com.google.rpc.Status.newBuilder().setCode(11).build()) + .setUpdatedSchema(UPDATED_TABLE_SCHEMA) + .build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + try { + appendFuture.get(); + Assert.fail("expected ExecutionException"); + } catch (ExecutionException ex) { + assertEquals(ex.getCause().getMessage(), "OUT_OF_RANGE: "); + int millis = 0; + while (millis <= 10000) { + if (writer.getDescriptor().getFields().size() == 2) { + break; + } + Thread.sleep(100); + millis += 100; + } + assertTrue(writer.getDescriptor().getFields().size() == 2); + } + + JSONObject updatedFoo = new JSONObject(); + updatedFoo.put("foo", "allen"); + updatedFoo.put("bar", "allen2"); + JSONArray updatedJsonArr = new JSONArray(); + updatedJsonArr.put(updatedFoo); + + ApiFuture appendFuture2 = + writer.append(updatedJsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(0L, appendFuture2.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRows(0), + UpdatedFooType.newBuilder().setFoo("allen").setBar("allen2").build().toByteString()); + + // Check if writer schemas were added in for both connections. 
+ assertTrue(testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + assertTrue(testBigQueryWrite.getAppendRequests().get(1).getProtoRows().hasWriterSchema()); + } + } + + @Test + public void testSchemaUpdateWithNonemptyBatch() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(2L) + .build()) + .build()) { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setOffset(0) + .setUpdatedSchema(UPDATED_TABLE_SCHEMA) + .build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(3).build()); + // First append + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + ApiFuture appendFuture1 = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + ApiFuture appendFuture2 = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + ApiFuture appendFuture3 = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + assertEquals( + 2, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(0), + FooType.newBuilder().setFoo("allen").build().toByteString()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRows(1), + FooType.newBuilder().setFoo("allen").build().toByteString()); + + assertEquals(2L, appendFuture3.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRows(0), + FooType.newBuilder().setFoo("allen").build().toByteString()); + + int millis = 0; + while (millis <= 10000) { + if (writer.getDescriptor().getFields().size() == 2) { + break; + } + Thread.sleep(100); + millis += 100; + } + assertTrue(writer.getDescriptor().getFields().size() == 2); + + // Second append with updated schema. 
+ JSONObject updatedFoo = new JSONObject(); + updatedFoo.put("foo", "allen"); + updatedFoo.put("bar", "allen2"); + JSONArray updatedJsonArr = new JSONArray(); + updatedJsonArr.put(updatedFoo); + + ApiFuture appendFuture4 = + writer.append(updatedJsonArr, -1, /* allowUnknownFields */ false); + + assertEquals(3L, appendFuture4.get().getOffset()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(2) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(2) + .getProtoRows() + .getRows() + .getSerializedRows(0), + UpdatedFooType.newBuilder().setFoo("allen").setBar("allen2").build().toByteString()); + + assertTrue(testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + assertTrue( + testBigQueryWrite.getAppendRequests().get(1).getProtoRows().hasWriterSchema() + || testBigQueryWrite.getAppendRequests().get(2).getProtoRows().hasWriterSchema()); + } + } + + @Test + public void testMultiThreadAppendNoSchemaUpdate() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .build()) + .build()) { + + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + final JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + final HashSet offset_sets = new HashSet(); + int thread_nums = 5; + Thread[] thread_arr = new Thread[thread_nums]; + for (int i = 0; i < thread_nums; i++) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset((long) i).build()); + offset_sets.add((long) i); + Thread t = + new Thread( + new Runnable() { + public void run() { + try { + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + AppendRowsResponse response = appendFuture.get(); + offset_sets.remove(response.getOffset()); + } catch (Exception e) { + LOG.severe("Thread execution failed: " + e.getMessage()); + } + } + }); + thread_arr[i] = t; + t.start(); + } + + for (int i = 0; i < thread_nums; i++) { + thread_arr[i].join(); + } + assertTrue(offset_sets.size() == 0); + for (int i = 0; i < thread_nums; i++) { + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRows(0), + FooType.newBuilder().setFoo("allen").build().toByteString()); + } + } + } + + @Test + public void testMultiThreadAppendWithSchemaUpdate() throws Exception { + try (JsonStreamWriter writer = + getTestJsonStreamWriterBuilder(TEST_STREAM, TABLE_SCHEMA) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .build()) + .build()) { + JSONObject foo = new JSONObject(); + foo.put("foo", "allen"); + final JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + final HashSet offsetSets = new HashSet(); + int numberThreads = 5; + Thread[] thread_arr = new Thread[numberThreads]; + for (int i = 0; i < numberThreads; i++) { + if (i == 2) { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder() + .setOffset((long) i) + .setUpdatedSchema(UPDATED_TABLE_SCHEMA) + .build()); + } else { + testBigQueryWrite.addResponse( + AppendRowsResponse.newBuilder().setOffset((long) i).build()); + } + + offsetSets.add((long) i); + Thread t = + new Thread( + new 
Runnable() { + public void run() { + try { + ApiFuture appendFuture = + writer.append(jsonArr, -1, /* allowUnknownFields */ false); + AppendRowsResponse response = appendFuture.get(); + offsetSets.remove(response.getOffset()); + } catch (Exception e) { + LOG.severe("Thread execution failed: " + e.getMessage()); + } + } + }); + thread_arr[i] = t; + t.start(); + } + + for (int i = 0; i < numberThreads; i++) { + thread_arr[i].join(); + } + assertTrue(offsetSets.size() == 0); + for (int i = 0; i < numberThreads; i++) { + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(i) + .getProtoRows() + .getRows() + .getSerializedRows(0), + FooType.newBuilder().setFoo("allen").build().toByteString()); + } + + int millis = 0; + while (millis <= 10000) { + if (writer.getDescriptor().getFields().size() == 2) { + break; + } + Thread.sleep(100); + millis += 100; + } + assertEquals(2, writer.getDescriptor().getFields().size()); + + foo.put("bar", "allen2"); + final JSONArray jsonArr2 = new JSONArray(); + jsonArr2.put(foo); + + for (int i = numberThreads; i < numberThreads + 5; i++) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset((long) i).build()); + offsetSets.add((long) i); + Thread t = + new Thread( + new Runnable() { + public void run() { + try { + ApiFuture appendFuture = + writer.append(jsonArr2, -1, /* allowUnknownFields */ false); + AppendRowsResponse response = appendFuture.get(); + offsetSets.remove(response.getOffset()); + } catch (Exception e) { + LOG.severe("Thread execution failed: " + e.getMessage()); + } + } + }); + thread_arr[i - 5] = t; + t.start(); + } + + for (int i = 0; i < numberThreads; i++) { + thread_arr[i].join(); + } + assertTrue(offsetSets.size() == 0); + for (int i = 0; i < numberThreads; i++) { + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(i + 5) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + testBigQueryWrite + .getAppendRequests() + .get(i + 5) + .getProtoRows() + .getRows() + .getSerializedRows(0), + UpdatedFooType.newBuilder().setFoo("allen").setBar("allen2").build().toByteString()); + } + } + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessageTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessageTest.java new file mode 100644 index 0000000000..ec5a7490ba --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/JsonToProtoMessageTest.java @@ -0,0 +1,750 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.bigquery.storage.v1beta2; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.bigquery.storage.test.JsonTest.*; +import com.google.cloud.bigquery.storage.test.SchemaTest.*; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.util.ArrayList; +import java.util.Map; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JsonToProtoMessageTest { + private static ImmutableMap AllTypesToDebugMessageTest = + new ImmutableMap.Builder() + .put(BoolType.getDescriptor(), "boolean") + .put(BytesType.getDescriptor(), "string") + .put(Int64Type.getDescriptor(), "int64") + .put(Int32Type.getDescriptor(), "int32") + .put(DoubleType.getDescriptor(), "double") + .put(StringType.getDescriptor(), "string") + .put(RepeatedType.getDescriptor(), "array") + .put(ObjectType.getDescriptor(), "object") + .build(); + + private static ImmutableMap AllTypesToCorrectProto = + new ImmutableMap.Builder() + .put( + BoolType.getDescriptor(), + new Message[] {BoolType.newBuilder().setTestFieldType(true).build()}) + .put( + BytesType.getDescriptor(), + new Message[] { + BytesType.newBuilder() + .setTestFieldType(ByteString.copyFrom("test".getBytes())) + .build() + }) + .put( + Int64Type.getDescriptor(), + new Message[] { + Int64Type.newBuilder().setTestFieldType(Long.MAX_VALUE).build(), + Int64Type.newBuilder().setTestFieldType(new Long(Integer.MAX_VALUE)).build() + }) + .put( + Int32Type.getDescriptor(), + new Message[] {Int32Type.newBuilder().setTestFieldType(Integer.MAX_VALUE).build()}) + .put( + DoubleType.getDescriptor(), + new Message[] {DoubleType.newBuilder().setTestFieldType(1.23).build()}) + .put( + StringType.getDescriptor(), + new Message[] {StringType.newBuilder().setTestFieldType("test").build()}) + .put( + RepeatedType.getDescriptor(), + new Message[] { + RepeatedType.newBuilder() + .addAllTestFieldType( + new ArrayList() { + { + add(1L); + add(2L); + add(3L); + } + }) + .build() + }) + .put( + ObjectType.getDescriptor(), + new Message[] { + ObjectType.newBuilder() + .setTestFieldType(ComplexLvl2.newBuilder().setTestInt(1).build()) + .build() + }) + .build(); + + private static ImmutableMap AllRepeatedTypesToDebugMessageTest = + new ImmutableMap.Builder() + .put(RepeatedBool.getDescriptor(), "boolean") + .put(RepeatedBytes.getDescriptor(), "string") + .put(RepeatedInt64.getDescriptor(), "int64") + .put(RepeatedInt32.getDescriptor(), "int32") + .put(RepeatedDouble.getDescriptor(), "double") + .put(RepeatedString.getDescriptor(), "string") + .put(RepeatedObject.getDescriptor(), "object") + .build(); + + private static ImmutableMap AllRepeatedTypesToCorrectProto = + new ImmutableMap.Builder() + .put( + RepeatedBool.getDescriptor(), + new Message[] { + RepeatedBool.newBuilder().addTestRepeated(true).addTestRepeated(false).build() + }) + .put( + RepeatedBytes.getDescriptor(), + new Message[] { + RepeatedBytes.newBuilder() + .addTestRepeated(ByteString.copyFrom("hello".getBytes())) + .addTestRepeated(ByteString.copyFrom("test".getBytes())) + .build() + }) + .put( + RepeatedString.getDescriptor(), + new Message[] { + RepeatedString.newBuilder().addTestRepeated("hello").addTestRepeated("test").build() + }) + .put( + RepeatedInt64.getDescriptor(), + new 
Message[] { + RepeatedInt64.newBuilder() + .addTestRepeated(Long.MAX_VALUE) + .addTestRepeated(Long.MIN_VALUE) + .addTestRepeated(Integer.MAX_VALUE) + .addTestRepeated(Integer.MIN_VALUE) + .addTestRepeated(Short.MAX_VALUE) + .addTestRepeated(Short.MIN_VALUE) + .addTestRepeated(Byte.MAX_VALUE) + .addTestRepeated(Byte.MIN_VALUE) + .addTestRepeated(0) + .build(), + RepeatedInt64.newBuilder() + .addTestRepeated(Integer.MAX_VALUE) + .addTestRepeated(Integer.MIN_VALUE) + .addTestRepeated(Short.MAX_VALUE) + .addTestRepeated(Short.MIN_VALUE) + .addTestRepeated(Byte.MAX_VALUE) + .addTestRepeated(Byte.MIN_VALUE) + .addTestRepeated(0) + .build() + }) + .put( + RepeatedInt32.getDescriptor(), + new Message[] { + RepeatedInt32.newBuilder() + .addTestRepeated(Integer.MAX_VALUE) + .addTestRepeated(Integer.MIN_VALUE) + .addTestRepeated(Short.MAX_VALUE) + .addTestRepeated(Short.MIN_VALUE) + .addTestRepeated(Byte.MAX_VALUE) + .addTestRepeated(Byte.MIN_VALUE) + .addTestRepeated(0) + .build() + }) + .put( + RepeatedDouble.getDescriptor(), + new Message[] { + RepeatedDouble.newBuilder() + .addTestRepeated(Double.MAX_VALUE) + .addTestRepeated(Double.MIN_VALUE) + .addTestRepeated(Float.MAX_VALUE) + .addTestRepeated(Float.MIN_VALUE) + .build(), + RepeatedDouble.newBuilder() + .addTestRepeated(Float.MAX_VALUE) + .addTestRepeated(Float.MIN_VALUE) + .build() + }) + .put( + RepeatedObject.getDescriptor(), + new Message[] { + RepeatedObject.newBuilder() + .addTestRepeated(ComplexLvl2.newBuilder().setTestInt(1).build()) + .addTestRepeated(ComplexLvl2.newBuilder().setTestInt(2).build()) + .addTestRepeated(ComplexLvl2.newBuilder().setTestInt(3).build()) + .build() + }) + .build(); + + private static JSONObject[] simpleJSONObjects = { + new JSONObject().put("test_field_type", Long.MAX_VALUE), + new JSONObject().put("test_field_type", Integer.MAX_VALUE), + new JSONObject().put("test_field_type", 1.23), + new JSONObject().put("test_field_type", true), + new JSONObject().put("test_field_type", "test"), + new JSONObject().put("test_field_type", new JSONArray("[1, 2, 3]")), + new JSONObject().put("test_field_type", new JSONObject().put("test_int", 1)) + }; + + private static JSONObject[] simpleJSONArrays = { + new JSONObject() + .put( + "test_repeated", + new JSONArray( + new Long[] { + Long.MAX_VALUE, + Long.MIN_VALUE, + (long) Integer.MAX_VALUE, + (long) Integer.MIN_VALUE, + (long) Short.MAX_VALUE, + (long) Short.MIN_VALUE, + (long) Byte.MAX_VALUE, + (long) Byte.MIN_VALUE, + 0L + })), + new JSONObject() + .put( + "test_repeated", + new JSONArray( + new Integer[] { + Integer.MAX_VALUE, + Integer.MIN_VALUE, + (int) Short.MAX_VALUE, + (int) Short.MIN_VALUE, + (int) Byte.MAX_VALUE, + (int) Byte.MIN_VALUE, + 0 + })), + new JSONObject() + .put( + "test_repeated", + new JSONArray( + new Double[] { + Double.MAX_VALUE, + Double.MIN_VALUE, + (double) Float.MAX_VALUE, + (double) Float.MIN_VALUE + })), + new JSONObject() + .put("test_repeated", new JSONArray(new Float[] {Float.MAX_VALUE, Float.MIN_VALUE})), + new JSONObject().put("test_repeated", new JSONArray(new Boolean[] {true, false})), + new JSONObject().put("test_repeated", new JSONArray(new String[] {"hello", "test"})), + new JSONObject() + .put( + "test_repeated", + new JSONArray( + new JSONObject[] { + new JSONObject().put("test_int", 1), + new JSONObject().put("test_int", 2), + new JSONObject().put("test_int", 3) + })) + }; + + @Test + public void testDifferentNameCasing() throws Exception { + TestInt64 expectedProto = + 
TestInt64.newBuilder().setByte(1).setShort(1).setInt(1).setLong(1).build(); + + JSONObject json = new JSONObject(); + json.put("bYtE", (byte) 1); + json.put("SHORT", (short) 1); + json.put("inT", 1); + json.put("lONg", 1L); + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt64.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testInt64() throws Exception { + TestInt64 expectedProto = + TestInt64.newBuilder().setByte(1).setShort(1).setInt(1).setLong(1).build(); + JSONObject json = new JSONObject(); + json.put("byte", (byte) 1); + json.put("short", (short) 1); + json.put("int", 1); + json.put("long", 1L); + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt64.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testInt32() throws Exception { + TestInt32 expectedProto = TestInt32.newBuilder().setByte(1).setShort(1).setInt(1).build(); + JSONObject json = new JSONObject(); + json.put("byte", (byte) 1); + json.put("short", (short) 1); + json.put("int", 1); + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt32.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testInt32NotMatchInt64() throws Exception { + JSONObject json = new JSONObject(); + json.put("byte", (byte) 1); + json.put("short", (short) 1); + json.put("int", 1L); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt32.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "JSONObject does not have a int32 field at root.int."); + } + } + + @Test + public void testDouble() throws Exception { + TestDouble expectedProto = TestDouble.newBuilder().setDouble(1.2).setFloat(3.4f).build(); + JSONObject json = new JSONObject(); + json.put("double", 1.2); + json.put("float", 3.4f); + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestDouble.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testAllTypes() throws Exception { + for (Map.Entry entry : AllTypesToDebugMessageTest.entrySet()) { + int success = 0; + for (JSONObject json : simpleJSONObjects) { + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(entry.getKey(), json, false); + assertEquals(protoMsg, AllTypesToCorrectProto.get(entry.getKey())[success]); + success += 1; + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject does not have a " + entry.getValue() + " field at root.test_field_type."); + } + } + if (entry.getKey() == Int64Type.getDescriptor()) { + assertEquals(2, success); + } else { + assertEquals(1, success); + } + } + } + + @Test + public void testAllRepeatedTypesWithLimits() throws Exception { + for (Map.Entry entry : AllRepeatedTypesToDebugMessageTest.entrySet()) { + int success = 0; + for (JSONObject json : simpleJSONArrays) { + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(entry.getKey(), json, false); + assertEquals(protoMsg, AllRepeatedTypesToCorrectProto.get(entry.getKey())[success]); + success += 1; + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject does not have a " + + entry.getValue() + + " field at root.test_repeated[0]."); + } + } + if (entry.getKey() == RepeatedInt64.getDescriptor() + || entry.getKey() == RepeatedDouble.getDescriptor()) 
{ + assertEquals(2, success); + } else { + assertEquals(1, success); + } + } + } + + @Test + public void testOptional() throws Exception { + TestInt64 expectedProto = TestInt64.newBuilder().setByte(1).build(); + JSONObject json = new JSONObject(); + json.put("byte", 1); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt64.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testRepeatedIsOptional() throws Exception { + TestRepeatedIsOptional expectedProto = + TestRepeatedIsOptional.newBuilder().setRequiredDouble(1.1).build(); + JSONObject json = new JSONObject(); + json.put("required_double", 1.1); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage( + TestRepeatedIsOptional.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testRequired() throws Exception { + JSONObject json = new JSONObject(); + json.put("optional_double", 1.1); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestRequired.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), "JSONObject does not have the required field root.required_double."); + } + } + + @Test + public void testStructSimple() throws Exception { + MessageType expectedProto = + MessageType.newBuilder() + .setTestFieldType(StringType.newBuilder().setTestFieldType("test").build()) + .build(); + JSONObject stringType = new JSONObject(); + stringType.put("test_field_type", "test"); + JSONObject json = new JSONObject(); + json.put("test_field_type", stringType); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(MessageType.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testStructSimpleFail() throws Exception { + JSONObject stringType = new JSONObject(); + stringType.put("test_field_type", 1); + JSONObject json = new JSONObject(); + json.put("test_field_type", stringType); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(MessageType.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject does not have a string field at root.test_field_type.test_field_type."); + } + } + + @Test + public void testStructComplex() throws Exception { + ComplexRoot expectedProto = + ComplexRoot.newBuilder() + .setTestInt(1) + .addTestString("a") + .addTestString("b") + .addTestString("c") + .setTestBytes(ByteString.copyFrom("hello".getBytes())) + .setTestBool(true) + .addTestDouble(1.1) + .addTestDouble(2.2) + .addTestDouble(3.3) + .addTestDouble(4.4) + .setTestDate(1) + .setComplexLvl1( + ComplexLvl1.newBuilder() + .setTestInt(2) + .setComplexLvl2(ComplexLvl2.newBuilder().setTestInt(3).build()) + .build()) + .setComplexLvl2(ComplexLvl2.newBuilder().setTestInt(3).build()) + .build(); + JSONObject complex_lvl2 = new JSONObject(); + complex_lvl2.put("test_int", 3); + + JSONObject complex_lvl1 = new JSONObject(); + complex_lvl1.put("test_int", 2); + complex_lvl1.put("complex_lvl2", complex_lvl2); + + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("test_string", new JSONArray(new String[] {"a", "b", "c"})); + json.put("test_bytes", "hello"); + json.put("test_bool", true); + json.put("test_DOUBLe", new JSONArray(new Double[] {1.1, 2.2, 3.3, 4.4})); + json.put("test_date", 1); + json.put("complex_lvl1", complex_lvl1); + json.put("complex_lvl2", 
complex_lvl2); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(ComplexRoot.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testStructComplexFail() throws Exception { + JSONObject complex_lvl2 = new JSONObject(); + complex_lvl2.put("test_int", 3); + + JSONObject complex_lvl1 = new JSONObject(); + complex_lvl1.put("test_int", "not_int"); + complex_lvl1.put("complex_lvl2", complex_lvl2); + + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("test_string", new JSONArray(new String[] {"a", "b", "c"})); + json.put("test_bytes", "hello"); + json.put("test_bool", true); + json.put("test_double", new JSONArray(new Double[] {1.1, 2.2, 3.3, 4.4})); + json.put("test_date", 1); + json.put("complex_lvl1", complex_lvl1); + json.put("complex_lvl2", complex_lvl2); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(ComplexRoot.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), "JSONObject does not have a int64 field at root.complex_lvl1.test_int."); + } + } + + @Test + public void testRepeatedWithMixedTypes() throws Exception { + JSONObject json = new JSONObject(); + json.put("test_repeated", new JSONArray("[1.1, 2.2, true]")); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(RepeatedDouble.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), "JSONObject does not have a double field at root.test_repeated[2]."); + } + } + + @Test + public void testNestedRepeatedComplex() throws Exception { + NestedRepeated expectedProto = + NestedRepeated.newBuilder() + .addDouble(1.1) + .addDouble(2.2) + .addDouble(3.3) + .addDouble(4.4) + .addDouble(5.5) + .addInt(1) + .addInt(2) + .addInt(3) + .addInt(4) + .addInt(5) + .setRepeatedString( + RepeatedString.newBuilder() + .addTestRepeated("hello") + .addTestRepeated("this") + .addTestRepeated("is") + .addTestRepeated("a") + .addTestRepeated("test") + .build()) + .build(); + double[] doubleArr = {1.1, 2.2, 3.3, 4.4, 5.5}; + String[] stringArr = {"hello", "this", "is", "a", "test"}; + int[] intArr = {1, 2, 3, 4, 5}; + + JSONObject json = new JSONObject(); + json.put("double", new JSONArray(doubleArr)); + json.put("int", new JSONArray(intArr)); + JSONObject jsonRepeatedString = new JSONObject(); + jsonRepeatedString.put("test_repeated", new JSONArray(stringArr)); + json.put("repeated_string", jsonRepeatedString); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(NestedRepeated.getDescriptor(), json, false); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testNestedRepeatedComplexFail() throws Exception { + double[] doubleArr = {1.1, 2.2, 3.3, 4.4, 5.5}; + Boolean[] fakeStringArr = {true, false}; + int[] intArr = {1, 2, 3, 4, 5}; + + JSONObject json = new JSONObject(); + json.put("double", new JSONArray(doubleArr)); + json.put("int", new JSONArray(intArr)); + JSONObject jsonRepeatedString = new JSONObject(); + jsonRepeatedString.put("test_repeated", new JSONArray(fakeStringArr)); + json.put("repeated_string", jsonRepeatedString); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(NestedRepeated.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject does not have a string field at root.repeated_string.test_repeated[0]."); + } + } + + @Test + public void 
testAllowUnknownFields() throws Exception { + RepeatedInt64 expectedProto = + RepeatedInt64.newBuilder() + .addTestRepeated(1) + .addTestRepeated(2) + .addTestRepeated(3) + .addTestRepeated(4) + .addTestRepeated(5) + .build(); + JSONObject json = new JSONObject(); + json.put("test_repeated", new JSONArray(new int[] {1, 2, 3, 4, 5})); + json.put("string", "hello"); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(RepeatedInt64.getDescriptor(), json, true); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testEmptySecondLevelObject() throws Exception { + ComplexLvl1 expectedProto = + ComplexLvl1.newBuilder() + .setTestInt(1) + .setComplexLvl2(ComplexLvl2.newBuilder().build()) + .build(); + JSONObject complexLvl2 = new JSONObject(); + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("complex_lvl2", complexLvl2); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(ComplexLvl1.getDescriptor(), json, true); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testAllowUnknownFieldsError() throws Exception { + JSONObject json = new JSONObject(); + json.put("test_repeated", new JSONArray(new int[] {1, 2, 3, 4, 5})); + json.put("string", "hello"); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(RepeatedInt64.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject has fields unknown to BigQuery: root.string. Set allowUnknownFields to True to allow unknown fields."); + } + } + + @Test + public void testEmptyProtoMessage() throws Exception { + JSONObject json = new JSONObject(); + json.put("test_repeated", new JSONArray(new int[0])); + json.put("string", "hello"); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(RepeatedInt64.getDescriptor(), json, true); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "The created protobuf message is empty."); + } + } + + @Test + public void testEmptyJSONObject() throws Exception { + JSONObject json = new JSONObject(); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(Int64Type.getDescriptor(), json, false); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), "JSONObject is empty."); + } + } + + @Test + public void testNullJson() throws Exception { + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(Int64Type.getDescriptor(), null, false); + } catch (NullPointerException e) { + assertEquals(e.getMessage(), "JSONObject is null."); + } + } + + @Test + public void testNullDescriptor() throws Exception { + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(null, new JSONObject(), false); + } catch (NullPointerException e) { + assertEquals(e.getMessage(), "Protobuf descriptor is null."); + } + } + + @Test + public void testAllowUnknownFieldsSecondLevel() throws Exception { + JSONObject complex_lvl2 = new JSONObject(); + complex_lvl2.put("no_match", 1); + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("complex_lvl2", complex_lvl2); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(ComplexLvl1.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "JSONObject has fields unknown to BigQuery: root.complex_lvl2.no_match. 
Set allowUnknownFields to True to allow unknown fields."); + } + } + + @Test + public void testTopLevelMismatch() throws Exception { + JSONObject json = new JSONObject(); + json.put("no_match", 1.1); + + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage( + TopLevelMismatch.getDescriptor(), json, true); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "There are no matching fields found for the JSONObject and the protocol buffer descriptor."); + } + } + + @Test + public void testTopLevelMatchSecondLevelMismatch() throws Exception { + ComplexLvl1 expectedProto = + ComplexLvl1.newBuilder() + .setTestInt(1) + .setComplexLvl2(ComplexLvl2.newBuilder().build()) + .build(); + JSONObject complex_lvl2 = new JSONObject(); + complex_lvl2.put("no_match", 1); + JSONObject json = new JSONObject(); + json.put("test_int", 1); + json.put("complex_lvl2", complex_lvl2); + + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(ComplexLvl1.getDescriptor(), json, true); + assertEquals(protoMsg, expectedProto); + } + + @Test + public void testJsonNullValue() throws Exception { + JSONObject json = new JSONObject(); + json.put("long", JSONObject.NULL); + json.put("int", 1); + try { + DynamicMessage protoMsg = + JsonToProtoMessage.convertJsonToProtoMessage(TestInt64.getDescriptor(), json, false); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "JSONObject does not have a int64 field at root.long."); + } + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverterTest.java new file mode 100644 index 0000000000..9e025a13eb --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/ProtoSchemaConverterTest.java @@ -0,0 +1,192 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.bigquery.storage.v1beta2; + +import com.google.api.gax.rpc.InvalidArgumentException; +import com.google.cloud.bigquery.storage.test.Test.*; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.Descriptors; +import org.junit.*; + +public class ProtoSchemaConverterTest { + @Test + public void convertSimple() { + AllSupportedTypes testProto = AllSupportedTypes.newBuilder().setStringValue("abc").build(); + ProtoSchema protoSchema = ProtoSchemaConverter.convert(testProto.getDescriptorForType()); + Assert.assertEquals( + "name: \"com_google_cloud_bigquery_storage_test_AllSupportedTypes\"\n" + + "field {\n" + + " name: \"int32_value\"\n" + + " number: 1\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_INT32\n" + + "}\n" + + "field {\n" + + " name: \"int64_value\"\n" + + " number: 2\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_INT64\n" + + "}\n" + + "field {\n" + + " name: \"uint32_value\"\n" + + " number: 3\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_UINT32\n" + + "}\n" + + "field {\n" + + " name: \"uint64_value\"\n" + + " number: 4\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_UINT64\n" + + "}\n" + + "field {\n" + + " name: \"float_value\"\n" + + " number: 5\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_FLOAT\n" + + "}\n" + + "field {\n" + + " name: \"double_value\"\n" + + " number: 6\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_DOUBLE\n" + + "}\n" + + "field {\n" + + " name: \"bool_value\"\n" + + " number: 7\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_BOOL\n" + + "}\n" + + "field {\n" + + " name: \"enum_value\"\n" + + " number: 8\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_ENUM\n" + + " type_name: \"com_google_cloud_bigquery_storage_test_TestEnum_E.TestEnum\"\n" + + "}\n" + + "field {\n" + + " name: \"string_value\"\n" + + " number: 9\n" + + " label: LABEL_REQUIRED\n" + + " type: TYPE_STRING\n" + + "}\n" + + "nested_type {\n" + + " name: \"com_google_cloud_bigquery_storage_test_TestEnum_E\"\n" + + " enum_type {\n" + + " name: \"TestEnum\"\n" + + " value {\n" + + " name: \"TestEnum0\"\n" + + " number: 0\n" + + " }\n" + + " value {\n" + + " name: \"TestEnum1\"\n" + + " number: 1\n" + + " }\n" + + " }\n" + + "}\n", + protoSchema.getProtoDescriptor().toString()); + } + + @Test + public void convertNested() { + ComplicateType testProto = ComplicateType.newBuilder().build(); + ProtoSchema protoSchema = ProtoSchemaConverter.convert(testProto.getDescriptorForType()); + Assert.assertEquals( + "name: \"com_google_cloud_bigquery_storage_test_ComplicateType\"\n" + + "field {\n" + + " name: \"nested_repeated_type\"\n" + + " number: 1\n" + + " label: LABEL_REPEATED\n" + + " type: TYPE_MESSAGE\n" + + " type_name: \"com_google_cloud_bigquery_storage_test_NestedType\"\n" + + "}\n" + + "field {\n" + + " name: \"inner_type\"\n" + + " number: 2\n" + + " label: LABEL_OPTIONAL\n" + + " type: TYPE_MESSAGE\n" + + " type_name: \"com_google_cloud_bigquery_storage_test_InnerType\"\n" + + "}\n" + + "nested_type {\n" + + " name: \"com_google_cloud_bigquery_storage_test_InnerType\"\n" + + " field {\n" + + " name: \"value\"\n" + + " number: 1\n" + + " label: LABEL_REPEATED\n" + + " type: TYPE_STRING\n" + + " }\n" + + "}\n" + + "nested_type {\n" + + " name: \"com_google_cloud_bigquery_storage_test_NestedType\"\n" + + " field {\n" + + " name: \"inner_type\"\n" + + " number: 1\n" + + " label: LABEL_REPEATED\n" + + " type: TYPE_MESSAGE\n" + + " type_name: \"com_google_cloud_bigquery_storage_test_InnerType\"\n" + + 
" }\n" + + "}\n", + protoSchema.getProtoDescriptor().toString()); + } + + @Test + public void convertRecursive() { + try { + RecursiveType testProto = RecursiveType.newBuilder().build(); + ProtoSchema protoSchema = ProtoSchemaConverter.convert(testProto.getDescriptorForType()); + Assert.fail("No exception raised"); + } catch (InvalidArgumentException e) { + Assert.assertEquals( + "Recursive type is not supported:com.google.cloud.bigquery.storage.test.RecursiveType", + e.getMessage()); + } + } + + @Test + public void convertRecursiveTopMessage() { + try { + RecursiveTypeTopMessage testProto = RecursiveTypeTopMessage.newBuilder().build(); + ProtoSchema protoSchema = ProtoSchemaConverter.convert(testProto.getDescriptorForType()); + Assert.fail("No exception raised"); + } catch (InvalidArgumentException e) { + Assert.assertEquals( + "Recursive type is not supported:com.google.cloud.bigquery.storage.test.RecursiveTypeTopMessage", + e.getMessage()); + } + } + + @Test + public void convertDuplicateType() { + DuplicateType testProto = DuplicateType.newBuilder().build(); + ProtoSchema protoSchema = ProtoSchemaConverter.convert(testProto.getDescriptorForType()); + + FileDescriptorProto fileDescriptorProto = + FileDescriptorProto.newBuilder() + .setName("foo.proto") + .addMessageType(protoSchema.getProtoDescriptor()) + .build(); + try { + Descriptors.FileDescriptor fs = + Descriptors.FileDescriptor.buildFrom( + fileDescriptorProto, new Descriptors.FileDescriptor[0]); + Descriptors.Descriptor type = + fs.findMessageTypeByName(protoSchema.getProtoDescriptor().getName()); + Assert.assertEquals(4, type.getFields().size()); + } catch (Descriptors.DescriptorValidationException ex) { + Assert.fail("Got unexpected exception: " + ex.getMessage()); + } + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibilityTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibilityTest.java new file mode 100644 index 0000000000..ee8761aea3 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/SchemaCompatibilityTest.java @@ -0,0 +1,1015 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.bigquery.storage.v1beta2; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import com.google.cloud.bigquery.*; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.LegacySQLTypeName; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.storage.test.SchemaTest.*; +import com.google.cloud.bigquery.storage.test.Test.FooType; +import com.google.protobuf.Descriptors; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import javax.annotation.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +@RunWith(JUnit4.class) +public class SchemaCompatibilityTest { + @Mock private BigQuery mockBigquery; + @Mock private Table mockBigqueryTable; + Descriptors.Descriptor[] type_descriptors = { + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + UInt32Type.getDescriptor(), + UInt64Type.getDescriptor(), + Fixed32Type.getDescriptor(), + Fixed64Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor(), + FloatType.getDescriptor(), + DoubleType.getDescriptor(), + BoolType.getDescriptor(), + BytesType.getDescriptor(), + StringType.getDescriptor(), + EnumType.getDescriptor(), + MessageType.getDescriptor(), + GroupType.getDescriptor() + }; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + when(mockBigquery.getTable(any(TableId.class))).thenReturn(mockBigqueryTable); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockBigquery); + verifyNoMoreInteractions(mockBigqueryTable); + } + + public void customizeSchema(final Schema schema) { + TableDefinition definition = + new TableDefinition() { + @Override + public Type getType() { + return null; + } + + @Nullable + @Override + public Schema getSchema() { + return schema; + } + + @Override + public Builder toBuilder() { + return null; + } + }; + when(mockBigqueryTable.getDefinition()).thenReturn(definition); + } + + @Test + public void testSuccess() throws Exception { + customizeSchema( + Schema.of( + Field.newBuilder("Foo", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", FooType.getDescriptor(), false); + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testBadTableName() throws Exception { + try { + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("blah", FooType.getDescriptor(), false); + fail("should fail"); + } catch (IllegalArgumentException expected) { + assertEquals("Invalid table name: blah", expected.getMessage()); + } + } + + @Test + public void testSupportedTypes() { + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + for (Descriptors.FieldDescriptor field : SupportedTypes.getDescriptor().getFields()) { + assertTrue(compact.isSupportedType(field)); + } + + for (Descriptors.FieldDescriptor field : NonSupportedTypes.getDescriptor().getFields()) { + assertFalse(compact.isSupportedType(field)); + } + } + + @Test + public void testMap() { + customizeSchema( + Schema.of( + Field.newBuilder("map_value", 
LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + Descriptors.Descriptor testMap = NonSupportedMap.getDescriptor(); + String protoName = testMap.getName() + ".map_value"; + try { + compact.check("projects/p/datasets/d/tables/t", testMap, false); + fail("Should not be supported: field contains map"); + } catch (IllegalArgumentException expected) { + assertEquals( + "Proto schema " + protoName + " is not supported: is a map field.", + expected.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testNestingSupportedSimple() { + Field BQSupportedNestingLvl2 = + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("nesting_value", LegacySQLTypeName.RECORD, BQSupportedNestingLvl2) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + Descriptors.Descriptor testNesting = SupportedNestingLvl1.getDescriptor(); + try { + compact.check("projects/p/datasets/d/tables/t", testNesting, false); + } catch (Exception e) { + fail(e.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testNestingSupportedStacked() { + Field BQSupportedNestingLvl2 = + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("nesting_value1", LegacySQLTypeName.RECORD, BQSupportedNestingLvl2) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("nesting_value2", LegacySQLTypeName.RECORD, BQSupportedNestingLvl2) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + Descriptors.Descriptor testNesting = SupportedNestingStacked.getDescriptor(); + try { + compact.check("projects/p/datasets/d/tables/t", testNesting, false); + } catch (Exception e) { + fail(e.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + /* + * This is not the "exact" test, as BigQuery fields cannot be recursive. Instead, this test uses + * two DIFFERENT records with the same name to simulate recursive protos (protos can't have the + * same name anyways unless they are the same proto). 
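+ * Even so, the compatibility check should flag nesting_value.nesting_value as a recursively nested message, which is what the assertion below verifies.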
+ */ + @Test + public void testNestingContainsRecursive() { + Field BQNonSupportedNestingRecursive = + Field.newBuilder( + "nesting_value", + LegacySQLTypeName.RECORD, + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build()) + .setMode(Field.Mode.NULLABLE) + .build(); + + customizeSchema( + Schema.of( + Field.newBuilder("int_value", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder( + "nesting_value", LegacySQLTypeName.RECORD, BQNonSupportedNestingRecursive) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + Descriptors.Descriptor testNesting = NonSupportedNestingContainsRecursive.getDescriptor(); + try { + compact.check("projects/p/datasets/d/tables/t", testNesting, false); + fail("Should not be supported: contains nested messages of more than 15 levels."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Proto schema " + + testNesting.getName() + + ".nesting_value.nesting_value is not supported: is a recursively nested message.", + expected.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testNestingRecursiveLimit() { + Field NonSupportedNestingLvl16 = + Field.newBuilder("test1", LegacySQLTypeName.INTEGER).setMode(Field.Mode.NULLABLE).build(); + Field NonSupportedNestingLvl15 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl16) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl14 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl15) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl13 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl14) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl12 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl13) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl11 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl12) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl10 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl11) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl9 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl10) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl8 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl9) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl7 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl8) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl6 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl7) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl5 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl6) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl4 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl5) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl3 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl4) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl2 = + 
Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl3) + .setMode(Field.Mode.NULLABLE) + .build(); + Field NonSupportedNestingLvl1 = + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl2) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("test1", LegacySQLTypeName.RECORD, NonSupportedNestingLvl1) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + Descriptors.Descriptor testNesting = NonSupportedNestingLvl0.getDescriptor(); + try { + compact.check("projects/p/datasets/d/tables/t", testNesting, false); + fail("Should not be supported: contains nested messages of more than 15 levels."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Proto schema " + + testNesting.getName() + + " is not supported: contains nested messages of more than 15 levels.", + expected.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testProtoMoreFields() { + Schema customSchema = Schema.of(Field.of("int32_value", LegacySQLTypeName.INTEGER)); + customizeSchema(customSchema); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + + try { + compact.check("projects/p/datasets/d/tables/t", SupportedTypes.getDescriptor(), false); + fail("Should fail: proto has more fields and allowUnknownFields flag is false."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Proto schema " + + SupportedTypes.getDescriptor().getName() + + " has " + + SupportedTypes.getDescriptor().getFields().size() + + " fields, while BQ schema t has " + + 1 + + " fields.", + expected.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testBQRepeated() { + customizeSchema( + Schema.of( + Field.newBuilder("repeated_mode", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.REPEATED) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", ProtoRepeatedBQRepeated.getDescriptor(), false); + try { + compact.check( + "projects/p/datasets/d/tables/t", ProtoOptionalBQRepeated.getDescriptor(), false); + fail("Should fail: BQ schema is repeated, but proto is optional."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Given proto field " + + ProtoOptionalBQRepeated.getDescriptor().getName() + + ".repeated_mode" + + " is not repeated but Big Query field t.repeated_mode is.", + expected.getMessage()); + } + + try { + compact.check( + "projects/p/datasets/d/tables/t", ProtoRequiredBQRepeated.getDescriptor(), false); + fail("Should fail: BQ schema is repeated, but proto is required."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Given proto field " + + ProtoRequiredBQRepeated.getDescriptor().getName() + + ".repeated_mode" + + " is not repeated but Big Query field t.repeated_mode is.", + expected.getMessage()); + } + verify(mockBigquery, times(3)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(3)).getDefinition(); + } + + @Test + public void testBQRequired() { + customizeSchema( + Schema.of( + Field.newBuilder("required_mode", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.REQUIRED) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + 
compact.check("projects/p/datasets/d/tables/t", ProtoRequiredBQRequired.getDescriptor(), false); + + try { + compact.check("projects/p/datasets/d/tables/t", ProtoNoneBQRequired.getDescriptor(), false); + fail("Should fail: BQ schema is required, but proto does not have this field."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The required Big Query field t.required_mode is missing in the proto schema " + + ProtoNoneBQRequired.getDescriptor().getName() + + ".", + expected.getMessage()); + } + + try { + compact.check( + "projects/p/datasets/d/tables/t", ProtoOptionalBQRequired.getDescriptor(), false); + fail("Should fail: BQ schema is required, but proto is optional."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Given proto field " + + ProtoOptionalBQRequired.getDescriptor().getName() + + ".required_mode is not required but Big Query field t.required_mode is.", + expected.getMessage()); + } + + try { + compact.check( + "projects/p/datasets/d/tables/t", ProtoRepeatedBQRequired.getDescriptor(), false); + fail("Should fail: BQ schema is required, but proto is repeated."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Given proto field " + + ProtoRepeatedBQRequired.getDescriptor().getName() + + ".required_mode is not required but Big Query field t.required_mode is.", + expected.getMessage()); + } + verify(mockBigquery, times(4)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(4)).getDefinition(); + } + + @Test + public void testBQOptional() { + customizeSchema( + Schema.of( + Field.newBuilder("optional_mode", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", ProtoOptionalBQOptional.getDescriptor(), false); + compact.check("projects/p/datasets/d/tables/t", ProtoRequiredBQOptional.getDescriptor(), false); + + try { + compact.check( + "projects/p/datasets/d/tables/t", ProtoRepeatedBQOptional.getDescriptor(), false); + fail("Should fail: BQ schema is nullable, but proto field is repeated."); + } catch (IllegalArgumentException expected) { + assertEquals( + "Given proto field " + + ProtoRepeatedBQOptional.getDescriptor().getName() + + ".optional_mode is repeated but Big Query field t.optional_mode is optional.", + expected.getMessage()); + } + + verify(mockBigquery, times(3)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(3)).getDefinition(); + } + + @Test + public void testBQBool() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.BOOLEAN) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>( + Arrays.asList( + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + UInt32Type.getDescriptor(), + UInt64Type.getDescriptor(), + Fixed32Type.getDescriptor(), + Fixed64Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor(), + BoolType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Boolean."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + 
descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQBytes() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.BYTES) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(BytesType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Bytes."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQDate() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.DATE) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>( + Arrays.asList( + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Date."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQDatetime() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.DATETIME) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(Int64Type.getDescriptor(), StringType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Datetime."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + 
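+ /* + * BQ Float columns should accept proto float and double fields; every other descriptor in type_descriptors is expected to be rejected. + */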
+ @Test + public void testBQFloat() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.FLOAT) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(FloatType.getDescriptor(), DoubleType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Float."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQGeography() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.GEOGRAPHY) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(StringType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Geography."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQInteger() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>( + Arrays.asList( + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + UInt32Type.getDescriptor(), + Fixed32Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor(), + EnumType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Integer."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQNumeric() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.NUMERIC) + .setMode(Field.Mode.NULLABLE) + 
.build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>( + Arrays.asList( + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + UInt32Type.getDescriptor(), + UInt64Type.getDescriptor(), + Fixed32Type.getDescriptor(), + Fixed64Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor(), + BytesType.getDescriptor(), + StringType.getDescriptor(), + FloatType.getDescriptor(), + DoubleType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Numeric."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQRecord() { + Field nestedMessage = + Field.newBuilder("test_field_type", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.RECORD, nestedMessage) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(MessageType.getDescriptor(), GroupType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ String."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQRecordMismatch() { + Field nestedMessage1 = + Field.newBuilder("test_field_type", LegacySQLTypeName.INTEGER) + .setMode(Field.Mode.NULLABLE) + .build(); + Field nestedMessage0 = + Field.newBuilder("mismatchlvl1", LegacySQLTypeName.RECORD, nestedMessage1) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("mismatchlvl0", LegacySQLTypeName.RECORD, nestedMessage0) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + try { + compact.check("projects/p/datasets/d/tables/t", MessageTypeMismatch.getDescriptor(), false); + fail("Should fail: Proto schema type should not match BQ String."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + MessageTypeMismatch.getDescriptor().getName() + + ".mismatchlvl0.mismatchlvl1.test_field_type does not have a matching type with the big query field t.mismatchlvl0.mismatchlvl1.test_field_type.", + expected.getMessage()); + } + verify(mockBigquery, 
times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testBQRecordMatch() { + Field nestedMessage1 = + Field.newBuilder("test_field_type", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(); + Field nestedMessage0 = + Field.newBuilder("mismatchlvl1", LegacySQLTypeName.RECORD, nestedMessage1) + .setMode(Field.Mode.NULLABLE) + .build(); + customizeSchema( + Schema.of( + Field.newBuilder("mismatchlvl0", LegacySQLTypeName.RECORD, nestedMessage0) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", MessageTypeMismatch.getDescriptor(), false); + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testBQString() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(StringType.getDescriptor(), EnumType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ String."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQTime() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.TIME) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>(Arrays.asList(Int64Type.getDescriptor(), StringType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Time."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + @Test + public void testBQTimestamp() { + customizeSchema( + Schema.of( + Field.newBuilder("test_field_type", LegacySQLTypeName.TIMESTAMP) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + HashSet compatible = + new HashSet<>( + Arrays.asList( + Int32Type.getDescriptor(), + Int64Type.getDescriptor(), + UInt32Type.getDescriptor(), + Fixed32Type.getDescriptor(), + SFixed32Type.getDescriptor(), + SFixed64Type.getDescriptor(), + 
EnumType.getDescriptor())); + + for (Descriptors.Descriptor descriptor : type_descriptors) { + if (compatible.contains(descriptor)) { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + } else { + try { + compact.check("projects/p/datasets/d/tables/t", descriptor, false); + fail("Should fail: Proto schema type should not match BQ Timestamp."); + } catch (IllegalArgumentException expected) { + assertEquals( + "The proto field " + + descriptor.getName() + + ".test_field_type does not have a matching type with the big query field t.test_field_type.", + expected.getMessage()); + } + } + } + verify(mockBigquery, times(16)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(16)).getDefinition(); + } + + /* + * Tests if having no matching fields in the top level causes an error. + */ + @Test + public void testBQTopLevelMismatch() { + customizeSchema( + Schema.of( + Field.newBuilder("test_toplevel_mismatch", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + try { + compact.check("projects/p/datasets/d/tables/t", StringType.getDescriptor(), false); + } catch (IllegalArgumentException expected) { + assertEquals( + "There is no matching fields found for the proto schema " + + StringType.getDescriptor().getName() + + " and the BQ table schema t.", + expected.getMessage()); + } + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + /* + * Tests if there is at least 1 matching field in the top level. + */ + @Test + public void testBQTopLevelMatch() { + Field nestedMessage0 = + Field.newBuilder("mismatch", LegacySQLTypeName.STRING).setMode(Field.Mode.NULLABLE).build(); + customizeSchema( + Schema.of( + Field.newBuilder("mismatch", LegacySQLTypeName.RECORD, nestedMessage0) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("match", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", TopLevelMatch.getDescriptor(), false); + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testAllowUnknownUnsupportedFields() { + customizeSchema( + Schema.of( + Field.newBuilder("string_value", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check( + "projects/p/datasets/d/tables/t", AllowUnknownUnsupportedFields.getDescriptor(), true); + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } + + @Test + public void testLowerCase() { + customizeSchema( + Schema.of( + Field.newBuilder("tEsT_fIeLd_TyPe", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build())); + SchemaCompatibility compact = SchemaCompatibility.getInstance(mockBigquery); + compact.check("projects/p/datasets/d/tables/t", StringType.getDescriptor(), true); + verify(mockBigquery, times(1)).getTable(any(TableId.class)); + verify(mockBigqueryTable, times(1)).getDefinition(); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterTest.java new file mode 
100644 index 0000000000..80d8c493dd --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterTest.java @@ -0,0 +1,903 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1beta2; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.api.core.ApiFuture; +import com.google.api.gax.batching.BatchingSettings; +import com.google.api.gax.batching.FlowControlSettings; +import com.google.api.gax.batching.FlowController; +import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.core.FixedExecutorProvider; +import com.google.api.gax.core.InstantiatingExecutorProvider; +import com.google.api.gax.core.NoCredentialsProvider; +import com.google.api.gax.grpc.testing.LocalChannelProvider; +import com.google.api.gax.grpc.testing.MockGrpcService; +import com.google.api.gax.grpc.testing.MockServiceHelper; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.DataLossException; +import com.google.cloud.bigquery.storage.test.Test.FooType; +import com.google.common.base.Strings; +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Timestamp; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.util.Arrays; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.threeten.bp.Duration; +import org.threeten.bp.Instant; + +@RunWith(JUnit4.class) +public class StreamWriterTest { + private static final Logger LOG = Logger.getLogger(StreamWriterTest.class.getName()); + private static final String TEST_STREAM = "projects/p/datasets/d/tables/t/streams/s"; + private static final ExecutorProvider SINGLE_THREAD_EXECUTOR = + InstantiatingExecutorProvider.newBuilder().setExecutorThreadCount(1).build(); + private static LocalChannelProvider channelProvider; + private FakeScheduledExecutorService fakeExecutor; + private FakeBigQueryWrite testBigQueryWrite; + private static MockServiceHelper serviceHelper; + + @Before + public void setUp() throws Exception { + testBigQueryWrite = new FakeBigQueryWrite(); + serviceHelper = + new MockServiceHelper( + UUID.randomUUID().toString(), Arrays.asList(testBigQueryWrite)); + serviceHelper.start(); + channelProvider = serviceHelper.createChannelProvider(); + fakeExecutor = new FakeScheduledExecutorService(); + testBigQueryWrite.setExecutor(fakeExecutor); + Instant time = Instant.now(); + Timestamp timestamp = + 
Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build(); + // Add enough GetWriteStream response. + for (int i = 0; i < 4; i++) { + testBigQueryWrite.addResponse( + WriteStream.newBuilder().setName(TEST_STREAM).setCreateTime(timestamp).build()); + } + } + + @After + public void tearDown() throws Exception { + LOG.info("tearDown called"); + serviceHelper.stop(); + } + + private StreamWriter.Builder getTestStreamWriterBuilder(String testStream) { + return StreamWriter.newBuilder(testStream) + .setChannelProvider(channelProvider) + .setExecutorProvider(SINGLE_THREAD_EXECUTOR) + .setCredentialsProvider(NoCredentialsProvider.create()); + } + + private StreamWriter.Builder getTestStreamWriterBuilder() { + return getTestStreamWriterBuilder(TEST_STREAM); + } + + private AppendRowsRequest createAppendRequest(String[] messages, long offset) { + AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder(); + AppendRowsRequest.ProtoData.Builder dataBuilder = AppendRowsRequest.ProtoData.newBuilder(); + dataBuilder.setWriterSchema( + ProtoSchema.newBuilder() + .setProtoDescriptor( + DescriptorProtos.DescriptorProto.newBuilder() + .setName("Message") + .addField( + DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("foo") + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING) + .setNumber(1) + .build()) + .build())); + ProtoRows.Builder rows = ProtoRows.newBuilder(); + for (String message : messages) { + FooType foo = FooType.newBuilder().setFoo(message).build(); + rows.addSerializedRows(foo.toByteString()); + } + if (offset > 0) { + requestBuilder.setOffset(Int64Value.of(offset)); + } + return requestBuilder + .setProtoRows(dataBuilder.setRows(rows.build()).build()) + .setWriteStream(TEST_STREAM) + .build(); + } + + private ApiFuture sendTestMessage(StreamWriter writer, String[] messages) { + return writer.append(createAppendRequest(messages, -1)); + } + + @Test + public void testTableName() throws Exception { + try (StreamWriter writer = getTestStreamWriterBuilder().build()) { + assertEquals("projects/p/datasets/d/tables/t", writer.getTableNameString()); + } + } + + @Test + public void testAppendByDuration() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofSeconds(5)) + .setElementCountThreshold(10L) + .build()) + .setExecutorProvider(FixedExecutorProvider.create(fakeExecutor)) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + + assertFalse(appendFuture1.isDone()); + assertFalse(appendFuture2.isDone()); + fakeExecutor.advanceTime(Duration.ofSeconds(10)); + + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + + assertEquals(1, testBigQueryWrite.getAppendRequests().size()); + + assertEquals( + 2, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + true, testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + writer.close(); + } + + @Test + public void testAppendByNumBatchedMessages() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + 
.toBuilder() + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(100)) + .build()) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + ApiFuture appendFuture3 = sendTestMessage(writer, new String[] {"C"}); + + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + + assertFalse(appendFuture3.isDone()); + + ApiFuture appendFuture4 = sendTestMessage(writer, new String[] {"D"}); + + assertEquals(2L, appendFuture3.get().getOffset()); + assertEquals(3L, appendFuture4.get().getOffset()); + + assertEquals(2, testBigQueryWrite.getAppendRequests().size()); + assertEquals( + 2, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + true, testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + assertEquals( + 2, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + false, testBigQueryWrite.getAppendRequests().get(1).getProtoRows().hasWriterSchema()); + writer.close(); + } + + @Test + public void testAppendByNumBytes() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + // Each message is 32 bytes, setting batch size to 70 bytes allows 2 messages. + .setRequestByteThreshold(70L) + .setDelayThreshold(Duration.ofSeconds(100000)) + .build()) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(3).build()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + ApiFuture appendFuture3 = sendTestMessage(writer, new String[] {"C"}); + + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + assertFalse(appendFuture3.isDone()); + + // This message is big enough to trigger send on the previous message and itself. 
+ ApiFuture appendFuture4 = + sendTestMessage(writer, new String[] {Strings.repeat("A", 100)}); + assertEquals(2L, appendFuture3.get().getOffset()); + assertEquals(3L, appendFuture4.get().getOffset()); + + assertEquals(3, testBigQueryWrite.getAppendRequests().size()); + + writer.close(); + } + + @Test + public void testWriteByShutdown() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofSeconds(100)) + .setElementCountThreshold(10L) + .build()) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0L).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(1L).build()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + + // Note we are not advancing time or reaching the count threshold but messages should + // still get written by call to shutdown + + writer.close(); + + // Verify the appends completed + assertTrue(appendFuture1.isDone()); + assertTrue(appendFuture2.isDone()); + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + } + + @Test + public void testWriteMixedSizeAndDuration() throws Exception { + try (StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(5)) + .build()) + .build()) { + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0L).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2L).build()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + + fakeExecutor.advanceTime(Duration.ofSeconds(2)); + assertFalse(appendFuture1.isDone()); + + ApiFuture appendFuture2 = + sendTestMessage(writer, new String[] {"B", "C"}); + + // Write triggered by batch size + assertEquals(0L, appendFuture1.get().getOffset()); + assertEquals(1L, appendFuture2.get().getOffset()); + + ApiFuture appendFuture3 = sendTestMessage(writer, new String[] {"D"}); + + assertFalse(appendFuture3.isDone()); + + // Write triggered by time + fakeExecutor.advanceTime(Duration.ofSeconds(5)); + + assertEquals(2L, appendFuture3.get().getOffset()); + + assertEquals( + 3, + testBigQueryWrite + .getAppendRequests() + .get(0) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + true, testBigQueryWrite.getAppendRequests().get(0).getProtoRows().hasWriterSchema()); + assertEquals( + 1, + testBigQueryWrite + .getAppendRequests() + .get(1) + .getProtoRows() + .getRows() + .getSerializedRowsCount()); + assertEquals( + false, testBigQueryWrite.getAppendRequests().get(1).getProtoRows().hasWriterSchema()); + } + } + + @Test + public void testFlowControlBehaviorBlock() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .setFlowControlSettings( + StreamWriter.Builder.DEFAULT_FLOW_CONTROL_SETTINGS + .toBuilder() + .setMaxOutstandingRequestBytes(40L) + .setLimitExceededBehavior(FlowController.LimitExceededBehavior.Block) + .build()) + .build()) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2L).build()); + 
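// With the 40-byte outstanding-request limit and Block behavior configured above, the second append (issued from a separate thread below) cannot proceed until the first response frees its bytes. +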
testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(3L).build()); + testBigQueryWrite.setResponseDelay(Duration.ofSeconds(10)); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + final StreamWriter writer1 = writer; + Runnable runnable = + new Runnable() { + @Override + public void run() { + ApiFuture appendFuture2 = + sendTestMessage(writer1, new String[] {"B"}); + } + }; + Thread t = new Thread(runnable); + t.start(); + assertEquals(true, t.isAlive()); + assertEquals(false, appendFuture1.isDone()); + // Wait is necessary for response to be scheduled before timer is advanced. + Thread.sleep(5000L); + fakeExecutor.advanceTime(Duration.ofSeconds(10)); + // The first request gets back while the second one is blocked. + assertEquals(2L, appendFuture1.get().getOffset()); + Thread.sleep(5000L); + // Wait is necessary for response to be scheduled before timer is advanced. + fakeExecutor.advanceTime(Duration.ofSeconds(10)); + t.join(); + writer.close(); + } + + @Test + public void testFlowControlBehaviorException() throws Exception { + try (StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .setFlowControlSettings( + StreamWriter.Builder.DEFAULT_FLOW_CONTROL_SETTINGS + .toBuilder() + .setMaxOutstandingElementCount(1L) + .setLimitExceededBehavior( + FlowController.LimitExceededBehavior.ThrowException) + .build()) + .build()) + .build()) { + assertEquals( + 1L, + writer + .getBatchingSettings() + .getFlowControlSettings() + .getMaxOutstandingElementCount() + .longValue()); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(1L).build()); + testBigQueryWrite.setResponseDelay(Duration.ofSeconds(10)); + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + // Wait is necessary for response to be scheduled before timer is advanced. + Thread.sleep(5000L); + fakeExecutor.advanceTime(Duration.ofSeconds(10)); + try { + appendFuture2.get(); + Assert.fail("This should fail"); + } catch (Exception e) { + assertEquals( + "java.util.concurrent.ExecutionException: The maximum number of batch elements: 1 have been reached.", + e.toString()); + } + assertEquals(1L, appendFuture1.get().getOffset()); + } + } + + @Test + public void testStreamReconnectionTransient() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofSeconds(100000)) + .setElementCountThreshold(1L) + .build()) + .build(); + + StatusRuntimeException transientError = new StatusRuntimeException(Status.UNAVAILABLE); + testBigQueryWrite.addException(transientError); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + ApiFuture future1 = sendTestMessage(writer, new String[] {"m1"}); + assertEquals(false, future1.isDone()); + // Retry is scheduled to be 7 seconds later. 
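+ // The transient UNAVAILABLE error above is retried automatically, so this append still completes with offset 0.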
+ assertEquals(0L, future1.get().getOffset()); + writer.close(); + } + + @Test + public void testStreamReconnectionPermanent() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofSeconds(100000)) + .setElementCountThreshold(1L) + .build()) + .build(); + StatusRuntimeException permanentError = new StatusRuntimeException(Status.INVALID_ARGUMENT); + testBigQueryWrite.addException(permanentError); + ApiFuture future2 = sendTestMessage(writer, new String[] {"m2"}); + try { + future2.get(); + Assert.fail("This should fail."); + } catch (ExecutionException e) { + assertEquals(permanentError.toString(), e.getCause().getCause().toString()); + } + writer.close(); + } + + @Test + public void testStreamReconnectionExceedRetry() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofSeconds(100000)) + .setElementCountThreshold(1L) + .build()) + .setRetrySettings( + RetrySettings.newBuilder() + .setMaxRetryDelay(Duration.ofMillis(100)) + .setMaxAttempts(1) + .build()) + .build(); + assertEquals(1, writer.getRetrySettings().getMaxAttempts()); + StatusRuntimeException transientError = new StatusRuntimeException(Status.UNAVAILABLE); + testBigQueryWrite.addException(transientError); + testBigQueryWrite.addException(transientError); + ApiFuture future3 = sendTestMessage(writer, new String[] {"toomanyretry"}); + try { + future3.get(); + Assert.fail("This should fail."); + } catch (ExecutionException e) { + assertEquals(transientError.toString(), e.getCause().getCause().toString()); + } + writer.close(); + } + + @Test + public void testOffset() throws Exception { + try (StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(2L) + .build()) + .build()) { + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(10L).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(13L).build()); + AppendRowsRequest request1 = createAppendRequest(new String[] {"A"}, 10L); + ApiFuture appendFuture1 = writer.append(request1); + AppendRowsRequest request2 = createAppendRequest(new String[] {"B", "C"}, 11L); + ApiFuture appendFuture2 = writer.append(request2); + AppendRowsRequest request3 = createAppendRequest(new String[] {"E", "F"}, 13L); + ApiFuture appendFuture3 = writer.append(request3); + AppendRowsRequest request4 = createAppendRequest(new String[] {"G"}, 15L); + ApiFuture appendFuture4 = writer.append(request4); + assertEquals(10L, appendFuture1.get().getOffset()); + assertEquals(11L, appendFuture2.get().getOffset()); + assertEquals(13L, appendFuture3.get().getOffset()); + assertEquals(15L, appendFuture4.get().getOffset()); + } + } + + @Test + public void testOffsetMismatch() throws Exception { + try (StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .build()) + .build()) { + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(11L).build()); + AppendRowsRequest request1 = createAppendRequest(new String[] {"A"}, 10L); + ApiFuture appendFuture1 = writer.append(request1); + + appendFuture1.get(); + fail("Should throw exception"); + } catch 
(Exception e) { + assertEquals( + "java.lang.IllegalStateException: The append result offset 11 does not match the expected offset 10.", + e.getCause().toString()); + } + } + + @Test + public void testErrorPropagation() throws Exception { + try (StreamWriter writer = + getTestStreamWriterBuilder() + .setExecutorProvider(SINGLE_THREAD_EXECUTOR) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .setDelayThreshold(Duration.ofSeconds(5)) + .build()) + .build()) { + testBigQueryWrite.addException(Status.DATA_LOSS.asException()); + sendTestMessage(writer, new String[] {"A"}).get(); + fail("should throw exception"); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(DataLossException.class); + } + } + + @Test + public void testWriterGetters() throws Exception { + StreamWriter.Builder builder = StreamWriter.newBuilder(TEST_STREAM); + builder.setChannelProvider(channelProvider); + builder.setExecutorProvider(SINGLE_THREAD_EXECUTOR); + builder.setBatchingSettings( + BatchingSettings.newBuilder() + .setRequestByteThreshold(10L) + .setDelayThreshold(Duration.ofMillis(11)) + .setElementCountThreshold(12L) + .setFlowControlSettings( + FlowControlSettings.newBuilder() + .setMaxOutstandingElementCount(100L) + .setMaxOutstandingRequestBytes(1000L) + .setLimitExceededBehavior(FlowController.LimitExceededBehavior.Block) + .build()) + .build()); + builder.setCredentialsProvider(NoCredentialsProvider.create()); + StreamWriter writer = builder.build(); + + assertEquals(TEST_STREAM, writer.getStreamNameString()); + assertEquals(10, (long) writer.getBatchingSettings().getRequestByteThreshold()); + assertEquals(Duration.ofMillis(11), writer.getBatchingSettings().getDelayThreshold()); + assertEquals(12, (long) writer.getBatchingSettings().getElementCountThreshold()); + assertEquals( + FlowController.LimitExceededBehavior.Block, + writer.getBatchingSettings().getFlowControlSettings().getLimitExceededBehavior()); + assertEquals( + 100L, + writer + .getBatchingSettings() + .getFlowControlSettings() + .getMaxOutstandingElementCount() + .longValue()); + assertEquals( + 1000L, + writer + .getBatchingSettings() + .getFlowControlSettings() + .getMaxOutstandingRequestBytes() + .longValue()); + writer.close(); + } + + @Test + public void testBuilderParametersAndDefaults() { + StreamWriter.Builder builder = StreamWriter.newBuilder(TEST_STREAM); + assertEquals(StreamWriter.Builder.DEFAULT_EXECUTOR_PROVIDER, builder.executorProvider); + assertEquals(100 * 1024L, builder.batchingSettings.getRequestByteThreshold().longValue()); + assertEquals(Duration.ofMillis(10), builder.batchingSettings.getDelayThreshold()); + assertEquals(100L, builder.batchingSettings.getElementCountThreshold().longValue()); + assertEquals(StreamWriter.Builder.DEFAULT_RETRY_SETTINGS, builder.retrySettings); + assertEquals(Duration.ofMillis(100), builder.retrySettings.getInitialRetryDelay()); + assertEquals(3, builder.retrySettings.getMaxAttempts()); + } + + @Test + public void testBuilderInvalidArguments() { + StreamWriter.Builder builder = StreamWriter.newBuilder(TEST_STREAM); + + try { + builder.setChannelProvider(null); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + + try { + builder.setExecutorProvider(null); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + try { + builder.setBatchingSettings( + 
StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setRequestByteThreshold(null) + .build()); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setRequestByteThreshold(0L) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setRequestByteThreshold(-1L) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofMillis(1)) + .build()); + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(null) + .build()); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setDelayThreshold(Duration.ofMillis(-1)) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .build()); + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(null) + .build()); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(0L) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + try { + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(-1L) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + try { + FlowControlSettings flowControlSettings = + FlowControlSettings.newBuilder().setMaxOutstandingElementCount(-1L).build(); + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setFlowControlSettings(flowControlSettings) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + try { + FlowControlSettings flowControlSettings = + FlowControlSettings.newBuilder().setMaxOutstandingRequestBytes(-1L).build(); + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setFlowControlSettings(flowControlSettings) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + try { + FlowControlSettings flowControlSettings = + FlowControlSettings.newBuilder() + .setLimitExceededBehavior(FlowController.LimitExceededBehavior.Ignore) + .build(); + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + 
.setFlowControlSettings(flowControlSettings) + .build()); + fail("Should have thrown an IllegalArgumentException"); + } catch (IllegalArgumentException expected) { + // Expected + } + + try { + FlowControlSettings flowControlSettings = + FlowControlSettings.newBuilder().setLimitExceededBehavior(null).build(); + builder.setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setFlowControlSettings(flowControlSettings) + .build()); + fail("Should have thrown an NullPointerException"); + } catch (NullPointerException expected) { + // Expected + } + } + + @Test + public void testExistingClient() throws Exception { + BigQueryWriteSettings settings = + BigQueryWriteSettings.newBuilder() + .setTransportChannelProvider(channelProvider) + .setCredentialsProvider(NoCredentialsProvider.create()) + .build(); + BigQueryWriteClient client = BigQueryWriteClient.create(settings); + StreamWriter writer = StreamWriter.newBuilder(TEST_STREAM, client).build(); + writer.close(); + assertFalse(client.isShutdown()); + client.shutdown(); + client.awaitTermination(1, TimeUnit.MINUTES); + } + + @Test + public void testFlushAll() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(100000)) + .build()) + .build(); + + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(0).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(2).build()); + testBigQueryWrite.addResponse(AppendRowsResponse.newBuilder().setOffset(3).build()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + ApiFuture appendFuture3 = sendTestMessage(writer, new String[] {"C"}); + + assertFalse(appendFuture3.isDone()); + writer.flushAll(100000); + + assertTrue(appendFuture3.isDone()); + + writer.close(); + } + + @Test + public void testFlushAllFailed() throws Exception { + StreamWriter writer = + getTestStreamWriterBuilder() + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(100000)) + .build()) + .build(); + + testBigQueryWrite.addException(Status.DATA_LOSS.asException()); + + ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer, new String[] {"B"}); + ApiFuture appendFuture3 = sendTestMessage(writer, new String[] {"C"}); + + assertFalse(appendFuture3.isDone()); + try { + writer.flushAll(100000); + fail("Should have thrown an Exception"); + } catch (Exception expected) { + if (expected.getCause() instanceof com.google.api.gax.rpc.DataLossException + || expected instanceof java.lang.InterruptedException) { + LOG.info("got: " + expected.toString()); + if (expected instanceof java.lang.InterruptedException) { + LOG.warning("Test returned early due to InterruptedException"); + return; + } + } else { + fail("Unexpected exception: " + expected.toString()); + } + } + + assertTrue(appendFuture3.isDone()); + + writer.close(); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/it/ITBigQueryWriteManualClientTest.java new file mode 100644 index 
0000000000..af4c58341a --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/it/ITBigQueryWriteManualClientTest.java @@ -0,0 +1,571 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.bigquery.storage.v1beta2.it; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.api.core.ApiFuture; +import com.google.cloud.ServiceOptions; +import com.google.cloud.bigquery.*; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.storage.test.Test.*; +import com.google.cloud.bigquery.storage.v1beta2.*; +import com.google.cloud.bigquery.testing.RemoteBigQueryHelper; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Int64Value; +import java.io.IOException; +import java.util.*; +import java.util.concurrent.*; +import java.util.logging.Logger; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.threeten.bp.Duration; + +/** Integration tests for BigQuery Write API. 
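+ * <p>These tests create a temporary dataset via RemoteBigQueryHelper and exercise committed and pending write streams, the JsonStreamWriter, and live schema updates against the real service.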
*/ +public class ITBigQueryWriteManualClientTest { + private static final Logger LOG = + Logger.getLogger(ITBigQueryWriteManualClientTest.class.getName()); + private static final String DATASET = RemoteBigQueryHelper.generateDatasetName(); + private static final String TABLE = "testtable"; + private static final String TABLE2 = "complicatedtable"; + private static final String DESCRIPTION = "BigQuery Write Java manual client test dataset"; + + private static BigQueryWriteClient client; + private static TableInfo tableInfo; + private static TableInfo tableInfo2; + private static String tableId; + private static String tableId2; + private static BigQuery bigquery; + + @BeforeClass + public static void beforeClass() throws IOException { + client = BigQueryWriteClient.create(); + + RemoteBigQueryHelper bigqueryHelper = RemoteBigQueryHelper.create(); + bigquery = bigqueryHelper.getOptions().getService(); + DatasetInfo datasetInfo = + DatasetInfo.newBuilder(/* datasetId = */ DATASET).setDescription(DESCRIPTION).build(); + bigquery.create(datasetInfo); + LOG.info("Created test dataset: " + DATASET); + tableInfo = + TableInfo.newBuilder( + TableId.of(DATASET, TABLE), + StandardTableDefinition.of( + Schema.of( + com.google.cloud.bigquery.Field.newBuilder("foo", LegacySQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build()))) + .build(); + com.google.cloud.bigquery.Field.Builder innerTypeFieldBuilder = + com.google.cloud.bigquery.Field.newBuilder( + "inner_type", + LegacySQLTypeName.RECORD, + com.google.cloud.bigquery.Field.newBuilder("value", LegacySQLTypeName.STRING) + .setMode(Field.Mode.REPEATED) + .build()); + + tableInfo2 = + TableInfo.newBuilder( + TableId.of(DATASET, TABLE2), + StandardTableDefinition.of( + Schema.of( + Field.newBuilder( + "nested_repeated_type", + LegacySQLTypeName.RECORD, + innerTypeFieldBuilder.setMode(Field.Mode.REPEATED).build()) + .setMode(Field.Mode.REPEATED) + .build(), + innerTypeFieldBuilder.setMode(Field.Mode.NULLABLE).build()))) + .build(); + bigquery.create(tableInfo); + bigquery.create(tableInfo2); + tableId = + String.format( + "projects/%s/datasets/%s/tables/%s", + ServiceOptions.getDefaultProjectId(), DATASET, TABLE); + tableId2 = + String.format( + "projects/%s/datasets/%s/tables/%s", + ServiceOptions.getDefaultProjectId(), DATASET, TABLE2); + } + + @AfterClass + public static void afterClass() { + if (client != null) { + client.close(); + } + + if (bigquery != null) { + RemoteBigQueryHelper.forceDelete(bigquery, DATASET); + LOG.info("Deleted test dataset: " + DATASET); + } + } + + private AppendRowsRequest.Builder createAppendRequest(String streamName, String[] messages) { + AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder(); + + AppendRowsRequest.ProtoData.Builder dataBuilder = AppendRowsRequest.ProtoData.newBuilder(); + dataBuilder.setWriterSchema(ProtoSchemaConverter.convert(FooType.getDescriptor())); + + ProtoRows.Builder rows = ProtoRows.newBuilder(); + for (String message : messages) { + FooType foo = FooType.newBuilder().setFoo(message).build(); + rows.addSerializedRows(foo.toByteString()); + } + dataBuilder.setRows(rows.build()); + return requestBuilder.setProtoRows(dataBuilder.build()).setWriteStream(streamName); + } + + private AppendRowsRequest.Builder createAppendRequestComplicateType( + String streamName, String[] messages) { + AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder(); + + AppendRowsRequest.ProtoData.Builder dataBuilder = AppendRowsRequest.ProtoData.newBuilder(); + 
dataBuilder.setWriterSchema(ProtoSchemaConverter.convert(ComplicateType.getDescriptor())); + + ProtoRows.Builder rows = ProtoRows.newBuilder(); + for (String message : messages) { + ComplicateType foo = + ComplicateType.newBuilder() + .setInnerType(InnerType.newBuilder().addValue(message).addValue(message).build()) + .build(); + rows.addSerializedRows(foo.toByteString()); + } + dataBuilder.setRows(rows.build()); + return requestBuilder.setProtoRows(dataBuilder.build()).setWriteStream(streamName); + } + + @Test + public void testBatchWriteWithCommittedStream() + throws IOException, InterruptedException, ExecutionException { + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(tableId) + .setWriteStream( + WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build()) + .build()); + try (StreamWriter streamWriter = + StreamWriter.newBuilder(writeStream.getName()) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setRequestByteThreshold(1024 * 1024L) // 1 Mb + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(2)) + .build()) + .build()) { + LOG.info("Sending one message"); + ApiFuture response = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"aaa"}).build()); + assertEquals(0, response.get().getOffset()); + + LOG.info("Sending two more messages"); + ApiFuture response1 = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"bbb", "ccc"}).build()); + ApiFuture response2 = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"ddd"}).build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(1, response1.get().getOffset()); + // assertEquals(3, response2.get().getOffset()); + response2.get(); + + TableResult result = + bigquery.listTableData( + tableInfo.getTableId(), BigQuery.TableDataListOption.startIndex(0L)); + Iterator iter = result.getValues().iterator(); + assertEquals("aaa", iter.next().get(0).getStringValue()); + assertEquals("bbb", iter.next().get(0).getStringValue()); + assertEquals("ccc", iter.next().get(0).getStringValue()); + assertEquals("ddd", iter.next().get(0).getStringValue()); + assertEquals(false, iter.hasNext()); + } + } + + @Test + public void testJsonStreamWriterBatchWriteWithCommittedStream() + throws IOException, InterruptedException, ExecutionException, + Descriptors.DescriptorValidationException { + String tableName = "JsonTable"; + TableInfo tableInfo = + TableInfo.newBuilder( + TableId.of(DATASET, tableName), + StandardTableDefinition.of( + Schema.of( + com.google.cloud.bigquery.Field.newBuilder("foo", LegacySQLTypeName.STRING) + .build()))) + .build(); + bigquery.create(tableInfo); + TableName parent = TableName.of(ServiceOptions.getDefaultProjectId(), DATASET, tableName); + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(parent.toString()) + .setWriteStream( + WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build()) + .build()); + try (JsonStreamWriter jsonStreamWriter = + JsonStreamWriter.newBuilder(writeStream.getName(), writeStream.getTableSchema()) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setRequestByteThreshold(1024 * 1024L) // 1 Mb + .setElementCountThreshold(2L) + .setDelayThreshold(Duration.ofSeconds(2)) + .build()) + .build()) { + LOG.info("Sending one message"); + JSONObject foo = new 
JSONObject(); + foo.put("foo", "aaa"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + ApiFuture response = + jsonStreamWriter.append(jsonArr, -1, /* allowUnknownFields */ false); + assertEquals(0, response.get().getOffset()); + + LOG.info("Sending two more messages"); + JSONObject foo1 = new JSONObject(); + foo1.put("foo", "bbb"); + JSONObject foo2 = new JSONObject(); + foo2.put("foo", "ccc"); + JSONArray jsonArr1 = new JSONArray(); + jsonArr1.put(foo1); + jsonArr1.put(foo2); + + JSONObject foo3 = new JSONObject(); + foo3.put("foo", "ddd"); + JSONArray jsonArr2 = new JSONArray(); + jsonArr2.put(foo3); + + ApiFuture response1 = + jsonStreamWriter.append(jsonArr1, -1, /* allowUnknownFields */ false); + ApiFuture response2 = + jsonStreamWriter.append(jsonArr2, -1, /* allowUnknownFields */ false); + // Waiting for API breaking change to be generated in new client. + // assertEquals(1, response1.get().getOffset()); + // assertEquals(3, response2.get().getOffset()); + response2.get(); + + TableResult result = + bigquery.listTableData( + tableInfo.getTableId(), BigQuery.TableDataListOption.startIndex(0L)); + Iterator iter = result.getValues().iterator(); + assertEquals("aaa", iter.next().get(0).getStringValue()); + assertEquals("bbb", iter.next().get(0).getStringValue()); + assertEquals("ccc", iter.next().get(0).getStringValue()); + assertEquals("ddd", iter.next().get(0).getStringValue()); + assertEquals(false, iter.hasNext()); + jsonStreamWriter.close(); + } + } + + @Test + public void testJsonStreamWriterSchemaUpdate() + throws IOException, InterruptedException, ExecutionException, + Descriptors.DescriptorValidationException { + String tableName = "SchemaUpdateTable"; + TableInfo tableInfo = + TableInfo.newBuilder( + TableId.of(DATASET, tableName), + StandardTableDefinition.of( + Schema.of( + com.google.cloud.bigquery.Field.newBuilder("foo", LegacySQLTypeName.STRING) + .build()))) + .build(); + + bigquery.create(tableInfo); + TableName parent = TableName.of(ServiceOptions.getDefaultProjectId(), DATASET, tableName); + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(parent.toString()) + .setWriteStream( + WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build()) + .build()); + + try (JsonStreamWriter jsonStreamWriter = + JsonStreamWriter.newBuilder(writeStream.getName(), writeStream.getTableSchema()) + .setBatchingSettings( + StreamWriter.Builder.DEFAULT_BATCHING_SETTINGS + .toBuilder() + .setElementCountThreshold(1L) + .build()) + .build()) { + // 1). Send 1 row + JSONObject foo = new JSONObject(); + foo.put("foo", "aaa"); + JSONArray jsonArr = new JSONArray(); + jsonArr.put(foo); + + ApiFuture response = + jsonStreamWriter.append(jsonArr, -1, /* allowUnknownFields */ false); + assertEquals(0, response.get().getOffset()); + // 2). Schema update and wait until querying it returns a new schema. 
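+ // A new "bar" STRING column is added to the table, then the schema is reloaded in a bounded loop until the update becomes visible.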
+ try { + com.google.cloud.bigquery.Table table = bigquery.getTable(DATASET, tableName); + Schema schema = table.getDefinition().getSchema(); + FieldList fields = schema.getFields(); + Field newField = + Field.newBuilder("bar", LegacySQLTypeName.STRING).setMode(Field.Mode.NULLABLE).build(); + + List fieldList = new ArrayList(); + fieldList.add(fields.get(0)); + fieldList.add(newField); + Schema newSchema = Schema.of(fieldList); + // Update the table with the new schema + com.google.cloud.bigquery.Table updatedTable = + table.toBuilder().setDefinition(StandardTableDefinition.of(newSchema)).build(); + updatedTable.update(); + int millis = 0; + while (millis <= 10000) { + if (newSchema.equals(table.reload().getDefinition().getSchema())) { + break; + } + Thread.sleep(1000); + millis += 1000; + } + newSchema = schema; + LOG.info( + "bar column successfully added to table in " + + millis + + " millis: " + + bigquery.getTable(DATASET, tableName).getDefinition().getSchema()); + } catch (BigQueryException e) { + LOG.severe("bar column was not added. \n" + e.toString()); + } + // 3). Send rows to wait for updatedSchema to be returned. + JSONObject foo2 = new JSONObject(); + foo2.put("foo", "bbb"); + JSONArray jsonArr2 = new JSONArray(); + jsonArr2.put(foo2); + + int next = 0; + for (int i = 1; i < 100; i++) { + ApiFuture response2 = + jsonStreamWriter.append(jsonArr2, -1, /* allowUnknownFields */ false); + // Waiting for API breaking change to be generated in new client. + // assertEquals(i, response2.get().getOffset()); + if (response2.get().hasUpdatedSchema()) { + next = i; + break; + } else { + Thread.sleep(1000); + } + } + + int millis = 0; + while (millis <= 10000) { + if (jsonStreamWriter.getDescriptor().getFields().size() == 2) { + LOG.info("JsonStreamWriter successfully updated internal descriptor!"); + break; + } + Thread.sleep(100); + millis += 100; + } + assertTrue(jsonStreamWriter.getDescriptor().getFields().size() == 2); + // 4). Send rows with updated schema. + JSONObject updatedFoo = new JSONObject(); + updatedFoo.put("foo", "ccc"); + updatedFoo.put("bar", "ddd"); + JSONArray updatedJsonArr = new JSONArray(); + updatedJsonArr.put(updatedFoo); + for (int i = 0; i < 10; i++) { + ApiFuture response3 = + jsonStreamWriter.append(updatedJsonArr, -1, /* allowUnknownFields */ false); + // Waiting for API breaking change to be generated in new client. 
+ // assertEquals(next + 1 + i, response3.get().getOffset()); + response3.get(); + } + + TableResult result3 = + bigquery.listTableData( + tableInfo.getTableId(), BigQuery.TableDataListOption.startIndex(0L)); + Iterator iter3 = result3.getValues().iterator(); + assertEquals("aaa", iter3.next().get(0).getStringValue()); + for (int j = 1; j <= next; j++) { + assertEquals("bbb", iter3.next().get(0).getStringValue()); + } + for (int j = next + 1; j < next + 1 + 10; j++) { + FieldValueList temp = iter3.next(); + assertEquals("ccc", temp.get(0).getStringValue()); + assertEquals("ddd", temp.get(1).getStringValue()); + } + assertEquals(false, iter3.hasNext()); + } + } + + @Test + public void testComplicateSchemaWithPendingStream() + throws IOException, InterruptedException, ExecutionException { + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(tableId2) + .setWriteStream(WriteStream.newBuilder().setType(WriteStream.Type.PENDING).build()) + .build()); + try (StreamWriter streamWriter = StreamWriter.newBuilder(writeStream.getName()).build()) { + LOG.info("Sending two messages"); + ApiFuture response = + streamWriter.append( + createAppendRequestComplicateType(writeStream.getName(), new String[] {"aaa"}) + .setOffset(Int64Value.of(0L)) + .build()); + assertEquals(0, response.get().getOffset()); + + ApiFuture response2 = + streamWriter.append( + createAppendRequestComplicateType(writeStream.getName(), new String[] {"bbb"}) + .setOffset(Int64Value.of(1L)) + .build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(1, response2.get().getOffset()); + + // Nothing showed up since rows are not committed. + TableResult result = + bigquery.listTableData( + tableInfo2.getTableId(), BigQuery.TableDataListOption.startIndex(0L)); + Iterator iter = result.getValues().iterator(); + assertEquals(false, iter.hasNext()); + + FinalizeWriteStreamResponse finalizeResponse = + client.finalizeWriteStream( + FinalizeWriteStreamRequest.newBuilder().setName(writeStream.getName()).build()); + + ApiFuture response3 = + streamWriter.append( + createAppendRequestComplicateType(writeStream.getName(), new String[] {"ccc"}) + .setOffset(Int64Value.of(1L)) + .build()); + try { + // Waiting for API breaking change to be generated in new client. + // assertEquals(2, response3.get().getOffset()); + // fail("Append to finalized stream should fail."); + } catch (Exception expected) { + // The exception thrown is not stable. Opened a bug to fix it. + } + } + // Finalize row count is not populated. + // assertEquals(1, finalizeResponse.getRowCount()); + BatchCommitWriteStreamsResponse batchCommitWriteStreamsResponse = + client.batchCommitWriteStreams( + BatchCommitWriteStreamsRequest.newBuilder() + .setParent(tableId2) + .addWriteStreams(writeStream.getName()) + .build()); + assertEquals(true, batchCommitWriteStreamsResponse.hasCommitTime()); + TableResult queryResult = + bigquery.query( + QueryJobConfiguration.newBuilder("SELECT * from " + DATASET + '.' 
+ TABLE2).build()); + Iterator queryIter = queryResult.getValues().iterator(); + assertTrue(queryIter.hasNext()); + assertEquals( + "[FieldValue{attribute=REPEATED, value=[FieldValue{attribute=PRIMITIVE, value=aaa}, FieldValue{attribute=PRIMITIVE, value=aaa}]}]", + queryIter.next().get(1).getRepeatedValue().toString()); + assertEquals( + "[FieldValue{attribute=REPEATED, value=[FieldValue{attribute=PRIMITIVE, value=bbb}, FieldValue{attribute=PRIMITIVE, value=bbb}]}]", + queryIter.next().get(1).getRepeatedValue().toString()); + assertFalse(queryIter.hasNext()); + } + + @Test + public void testStreamError() throws IOException, InterruptedException, ExecutionException { + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(tableId) + .setWriteStream( + WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build()) + .build()); + try (StreamWriter streamWriter = StreamWriter.newBuilder(writeStream.getName()).build()) { + AppendRowsRequest request = + createAppendRequest(writeStream.getName(), new String[] {"aaa"}).build(); + request + .toBuilder() + .setProtoRows(request.getProtoRows().toBuilder().clearWriterSchema().build()) + .build(); + ApiFuture response = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"aaa"}).build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(0L, response.get().getOffset()); + response.get(); + // Sending an offset that is beyond the current end of the stream should cause an error. + ApiFuture response2 = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"aaa"}) + .setOffset(Int64Value.of(100L)) + .build()); + try { + response2.get(); + Assert.fail("Should fail"); + } catch (ExecutionException e) { + assertThat(e.getCause().getMessage()) + .contains("OUT_OF_RANGE: The offset is beyond stream, expected offset 1, received 100"); + } + // We can keep sending requests on the same stream. + ApiFuture response3 = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"aaa"}).build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(1L, response3.get().getOffset()); + } finally { + } + } + + @Test + public void testStreamReconnect() throws IOException, InterruptedException, ExecutionException { + WriteStream writeStream = + client.createWriteStream( + CreateWriteStreamRequest.newBuilder() + .setParent(tableId) + .setWriteStream( + WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build()) + .build()); + try (StreamWriter streamWriter = StreamWriter.newBuilder(writeStream.getName()).build()) { + ApiFuture response = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"aaa"}) + .setOffset(Int64Value.of(0L)) + .build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(0L, response.get().getOffset()); + } + + try (StreamWriter streamWriter = StreamWriter.newBuilder(writeStream.getName()).build()) { + // Currently there is a bug that reconnection must wait 5 seconds to get the real row count. + Thread.sleep(5000L); + ApiFuture response = + streamWriter.append( + createAppendRequest(writeStream.getName(), new String[] {"bbb"}) + .setOffset(Int64Value.of(1L)) + .build()); + // Waiting for API breaking change to be generated in new client. + // assertEquals(1L, response.get().getOffset()); + } + } +}