diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptor.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptor.java new file mode 100644 index 0000000000..946d2bc7c8 --- /dev/null +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptor.java @@ -0,0 +1,149 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1alpha2; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * Converts a BQ table schema to protobuf descriptor. The mapping between field types and field + * modes are shown in the ImmutableMaps below. + */ +public class BQTableSchemaToProtoDescriptor { + private static ImmutableMap + BQTableSchemaModeMap = + ImmutableMap.of( + Table.TableFieldSchema.Mode.NULLABLE, FieldDescriptorProto.Label.LABEL_OPTIONAL, + Table.TableFieldSchema.Mode.REPEATED, FieldDescriptorProto.Label.LABEL_REPEATED, + Table.TableFieldSchema.Mode.REQUIRED, FieldDescriptorProto.Label.LABEL_REQUIRED); + + private static ImmutableMap + BQTableSchemaTypeMap = + new ImmutableMap.Builder() + .put(Table.TableFieldSchema.Type.BOOL, FieldDescriptorProto.Type.TYPE_BOOL) + .put(Table.TableFieldSchema.Type.BYTES, FieldDescriptorProto.Type.TYPE_BYTES) + .put(Table.TableFieldSchema.Type.DATE, FieldDescriptorProto.Type.TYPE_INT64) + .put(Table.TableFieldSchema.Type.DATETIME, FieldDescriptorProto.Type.TYPE_INT64) + .put(Table.TableFieldSchema.Type.DOUBLE, FieldDescriptorProto.Type.TYPE_DOUBLE) + .put(Table.TableFieldSchema.Type.GEOGRAPHY, FieldDescriptorProto.Type.TYPE_BYTES) + .put(Table.TableFieldSchema.Type.INT64, FieldDescriptorProto.Type.TYPE_INT64) + .put(Table.TableFieldSchema.Type.NUMERIC, FieldDescriptorProto.Type.TYPE_BYTES) + .put(Table.TableFieldSchema.Type.STRING, FieldDescriptorProto.Type.TYPE_STRING) + .put(Table.TableFieldSchema.Type.STRUCT, FieldDescriptorProto.Type.TYPE_MESSAGE) + .put(Table.TableFieldSchema.Type.TIME, FieldDescriptorProto.Type.TYPE_INT64) + .put(Table.TableFieldSchema.Type.TIMESTAMP, FieldDescriptorProto.Type.TYPE_INT64) + .build(); + + /** + * Converts Table.TableSchema to a Descriptors.Descriptor object. + * + * @param BQTableSchema + * @throws Descriptors.DescriptorValidationException + */ + public static Descriptor ConvertBQTableSchemaToProtoDescriptor(Table.TableSchema BQTableSchema) + throws Descriptors.DescriptorValidationException { + return ConvertBQTableSchemaToProtoDescriptorImpl( + BQTableSchema, "root", new HashMap, Descriptor>()); + } + + /** + * Converts a Table.TableSchema to a Descriptors.Descriptor object. + * + * @param BQTableSchema + * @param scope Keeps track of current scope to prevent repeated naming while constructing + * descriptor. + * @param dependencyMap Stores already constructed descriptors to prevent reconstruction + * @throws Descriptors.DescriptorValidationException + */ + private static Descriptor ConvertBQTableSchemaToProtoDescriptorImpl( + Table.TableSchema BQTableSchema, + String scope, + HashMap, Descriptor> dependencyMap) + throws Descriptors.DescriptorValidationException { + List dependenciesList = new ArrayList(); + List fields = new ArrayList(); + int index = 1; + for (Table.TableFieldSchema BQTableField : BQTableSchema.getFieldsList()) { + String currentScope = scope + "__" + BQTableField.getName(); + if (BQTableField.getType() == Table.TableFieldSchema.Type.STRUCT) { + ImmutableList fieldList = + ImmutableList.copyOf(BQTableField.getFieldsList()); + if (dependencyMap.containsKey(fieldList)) { + Descriptor descriptor = dependencyMap.get(fieldList); + dependenciesList.add(descriptor.getFile()); + fields.add(ConvertBQTableFieldToProtoField(BQTableField, index++, descriptor.getName())); + } else { + Descriptor descriptor = + ConvertBQTableSchemaToProtoDescriptorImpl( + Table.TableSchema.newBuilder().addAllFields(fieldList).build(), + currentScope, + dependencyMap); + dependenciesList.add(descriptor.getFile()); + dependencyMap.put(fieldList, descriptor); + fields.add(ConvertBQTableFieldToProtoField(BQTableField, index++, currentScope)); + } + } else { + fields.add(ConvertBQTableFieldToProtoField(BQTableField, index++, currentScope)); + } + } + FileDescriptor[] dependenciesArray = new FileDescriptor[dependenciesList.size()]; + dependenciesArray = dependenciesList.toArray(dependenciesArray); + DescriptorProto descriptorProto = + DescriptorProto.newBuilder().setName(scope).addAllField(fields).build(); + FileDescriptorProto fileDescriptorProto = + FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); + FileDescriptor fileDescriptor = + FileDescriptor.buildFrom(fileDescriptorProto, dependenciesArray); + Descriptor descriptor = fileDescriptor.findMessageTypeByName(scope); + return descriptor; + } + + /** + * Converts a BQTableField to ProtoField + * + * @param BQTableField BQ Field used to construct a FieldDescriptorProto + * @param index Index for protobuf fields. + * @param scope used to name descriptors + */ + private static FieldDescriptorProto ConvertBQTableFieldToProtoField( + Table.TableFieldSchema BQTableField, int index, String scope) { + Table.TableFieldSchema.Mode mode = BQTableField.getMode(); + String fieldName = BQTableField.getName(); + if (BQTableField.getType() == Table.TableFieldSchema.Type.STRUCT) { + return FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setTypeName(scope) + .setLabel((FieldDescriptorProto.Label) BQTableSchemaModeMap.get(mode)) + .setNumber(index) + .build(); + } + return FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setType((FieldDescriptorProto.Type) BQTableSchemaTypeMap.get(BQTableField.getType())) + .setLabel((FieldDescriptorProto.Label) BQTableSchemaModeMap.get(mode)) + .setNumber(index) + .build(); + } +} diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptorTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptorTest.java new file mode 100644 index 0000000000..e2cb04d1f4 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/BQTableSchemaToProtoDescriptorTest.java @@ -0,0 +1,280 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigquery.storage.v1alpha2; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import com.google.cloud.bigquery.storage.test.JsonTest.*; +import com.google.cloud.bigquery.storage.test.SchemaTest.*; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BQTableSchemaToProtoDescriptorTest { + // This is a map between the Table.TableFieldSchema.Type and the descriptor it is supposed to + // produce. The produced descriptor will be used to check against the entry values here. + private static ImmutableMap + BQTableTypeToCorrectProtoDescriptorTest = + new ImmutableMap.Builder() + .put(Table.TableFieldSchema.Type.BOOL, BoolType.getDescriptor()) + .put(Table.TableFieldSchema.Type.BYTES, BytesType.getDescriptor()) + .put(Table.TableFieldSchema.Type.DATE, Int64Type.getDescriptor()) + .put(Table.TableFieldSchema.Type.DATETIME, Int64Type.getDescriptor()) + .put(Table.TableFieldSchema.Type.DOUBLE, DoubleType.getDescriptor()) + .put(Table.TableFieldSchema.Type.GEOGRAPHY, BytesType.getDescriptor()) + .put(Table.TableFieldSchema.Type.INT64, Int64Type.getDescriptor()) + .put(Table.TableFieldSchema.Type.NUMERIC, BytesType.getDescriptor()) + .put(Table.TableFieldSchema.Type.STRING, StringType.getDescriptor()) + .put(Table.TableFieldSchema.Type.TIME, Int64Type.getDescriptor()) + .put(Table.TableFieldSchema.Type.TIMESTAMP, Int64Type.getDescriptor()) + .build(); + + // Creates mapping from descriptor to how many times it was reused. + private void mapDescriptorToCount(Descriptor descriptor, HashMap map) { + for (FieldDescriptor field : descriptor.getFields()) { + if (field.getType() == FieldDescriptor.Type.MESSAGE) { + Descriptor subDescriptor = field.getMessageType(); + String messageName = subDescriptor.getName(); + if (map.containsKey(messageName)) { + map.put(messageName, map.get(messageName) + 1); + } else { + map.put(messageName, 1); + } + mapDescriptorToCount(subDescriptor, map); + } + } + } + + private void isDescriptorEqual(Descriptor convertedProto, Descriptor originalProto) { + // Check same number of fields + assertEquals(convertedProto.getFields().size(), originalProto.getFields().size()); + for (FieldDescriptor convertedField : convertedProto.getFields()) { + // Check field name + FieldDescriptor originalField = originalProto.findFieldByName(convertedField.getName()); + assertNotNull(originalField); + // Check type + FieldDescriptor.Type convertedType = convertedField.getType(); + FieldDescriptor.Type originalType = originalField.getType(); + assertEquals(convertedType, originalType); + // Check mode + assertTrue( + (originalField.isRepeated() == convertedField.isRepeated()) + && (originalField.isRequired() == convertedField.isRequired()) + && (originalField.isOptional() == convertedField.isOptional())); + // Recursively check nested messages + if (convertedType == FieldDescriptor.Type.MESSAGE) { + isDescriptorEqual(convertedField.getMessageType(), originalField.getMessageType()); + } + } + } + + @Test + public void testSimpleTypes() throws Exception { + for (Map.Entry entry : + BQTableTypeToCorrectProtoDescriptorTest.entrySet()) { + final Table.TableFieldSchema tableFieldSchema = + Table.TableFieldSchema.newBuilder() + .setType(entry.getKey()) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .build(); + final Table.TableSchema tableSchema = + Table.TableSchema.newBuilder().addFields(0, tableFieldSchema).build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.ConvertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, entry.getValue()); + } + } + + @Test + public void testStructSimple() throws Exception { + final Table.TableFieldSchema StringType = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRING) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .build(); + final Table.TableFieldSchema tableFieldSchema = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_field_type") + .addFields(0, StringType) + .build(); + final Table.TableSchema tableSchema = + Table.TableSchema.newBuilder().addFields(0, tableFieldSchema).build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.ConvertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, MessageType.getDescriptor()); + } + + @Test + public void testStructComplex() throws Exception { + final Table.TableFieldSchema test_int = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.INT64) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_int") + .build(); + final Table.TableFieldSchema test_string = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRING) + .setMode(Table.TableFieldSchema.Mode.REPEATED) + .setName("test_string") + .build(); + final Table.TableFieldSchema test_bytes = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.BYTES) + .setMode(Table.TableFieldSchema.Mode.REQUIRED) + .setName("test_bytes") + .build(); + final Table.TableFieldSchema test_bool = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.BOOL) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_bool") + .build(); + final Table.TableFieldSchema test_double = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.DOUBLE) + .setMode(Table.TableFieldSchema.Mode.REPEATED) + .setName("test_double") + .build(); + final Table.TableFieldSchema ComplexLvl2 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.REQUIRED) + .addFields(0, test_int) + .setName("complexLvl2") + .build(); + final Table.TableFieldSchema ComplexLvl1 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.REQUIRED) + .addFields(0, test_int) + .addFields(1, ComplexLvl2) + .setName("complexLvl1") + .build(); + final Table.TableSchema tableSchema = + Table.TableSchema.newBuilder() + .addFields(0, test_int) + .addFields(1, test_string) + .addFields(2, test_bytes) + .addFields(3, test_bool) + .addFields(4, test_double) + .addFields(5, ComplexLvl1) + .addFields(6, ComplexLvl2) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.ConvertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, ComplexRoot.getDescriptor()); + } + + @Test + public void testOptions() throws Exception { + final Table.TableFieldSchema required = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.INT64) + .setMode(Table.TableFieldSchema.Mode.REQUIRED) + .setName("test_required") + .build(); + final Table.TableFieldSchema repeated = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.INT64) + .setMode(Table.TableFieldSchema.Mode.REPEATED) + .setName("test_repeated") + .build(); + final Table.TableFieldSchema optional = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.INT64) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_optional") + .build(); + final Table.TableSchema tableSchema = + Table.TableSchema.newBuilder() + .addFields(0, required) + .addFields(1, repeated) + .addFields(2, optional) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.ConvertBQTableSchemaToProtoDescriptor(tableSchema); + isDescriptorEqual(descriptor, OptionTest.getDescriptor()); + } + + @Test + public void testDescriptorReuseDuringCreation() throws Exception { + final Table.TableFieldSchema test_int = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.INT64) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("test_int") + .build(); + final Table.TableFieldSchema reuse_lvl2 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl2") + .addFields(0, test_int) + .build(); + final Table.TableFieldSchema reuse_lvl1 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + .build(); + final Table.TableFieldSchema reuse_lvl1_1 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1_1") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + .build(); + final Table.TableFieldSchema reuse_lvl1_2 = + Table.TableFieldSchema.newBuilder() + .setType(Table.TableFieldSchema.Type.STRUCT) + .setMode(Table.TableFieldSchema.Mode.NULLABLE) + .setName("reuse_lvl1_2") + .addFields(0, test_int) + .addFields(0, reuse_lvl2) + .build(); + final Table.TableSchema tableSchema = + Table.TableSchema.newBuilder() + .addFields(0, reuse_lvl1) + .addFields(1, reuse_lvl1_1) + .addFields(2, reuse_lvl1_2) + .build(); + final Descriptor descriptor = + BQTableSchemaToProtoDescriptor.ConvertBQTableSchemaToProtoDescriptor(tableSchema); + HashMap descriptorToCount = new HashMap(); + mapDescriptorToCount(descriptor, descriptorToCount); + assertEquals(descriptorToCount.size(), 2); + assertTrue(descriptorToCount.containsKey("root__reuse_lvl1")); + assertEquals(descriptorToCount.get("root__reuse_lvl1").intValue(), 3); + assertTrue(descriptorToCount.containsKey("root__reuse_lvl1__reuse_lvl2")); + assertEquals(descriptorToCount.get("root__reuse_lvl1__reuse_lvl2").intValue(), 3); + isDescriptorEqual(descriptor, ReuseRoot.getDescriptor()); + } +} diff --git a/google-cloud-bigquerystorage/src/test/proto/jsonTest.proto b/google-cloud-bigquerystorage/src/test/proto/jsonTest.proto new file mode 100644 index 0000000000..c531e09096 --- /dev/null +++ b/google-cloud-bigquerystorage/src/test/proto/jsonTest.proto @@ -0,0 +1,43 @@ +syntax = "proto2"; + +package com.google.cloud.bigquery.storage.test; + +message ComplexRoot { + optional int64 test_int = 1; + repeated string test_string = 2; + required bytes test_bytes = 3; + optional bool test_bool = 4; + repeated double test_double = 5; + required ComplexLvl1 complexLvl1 = 6; + required ComplexLvl2 complexLvl2 = 7; +} + +message ComplexLvl1 { + optional int64 test_int = 1; + required ComplexLvl2 complexLvl2 = 2; +} + +message ComplexLvl2 { + optional int64 test_int = 1; +} + +message OptionTest { + optional int64 test_optional = 1; + required int64 test_required = 2; + repeated int64 test_repeated = 3; +} + +message ReuseRoot { + optional ReuseLvl1 reuse_lvl1 = 1; + optional ReuseLvl1 reuse_lvl1_1 = 2; + optional ReuseLvl1 reuse_lvl1_2 = 3; +} + +message ReuseLvl1 { + optional int64 test_int = 1; + optional ReuseLvl2 reuse_lvl2 = 2; +} + +message ReuseLvl2 { + optional int64 test_int = 1; +}