Skip to content

Commit

Permalink
Add protobuf codegen decoder with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rseetham committed May 8, 2024
1 parent 363a03e commit 8cadb65
Show file tree
Hide file tree
Showing 19 changed files with 2,161 additions and 107 deletions.
5 changes: 5 additions & 0 deletions pinot-plugins/pinot-input-format/pinot-protobuf/pom.xml
Expand Up @@ -50,6 +50,11 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</dependency>
<dependency>
<groupId>org.codehaus.janino</groupId>
<artifactId>janino</artifactId>
<version>3.1.6</version>
</dependency>
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-dynamic</artifactId>
Expand Down
@@ -0,0 +1,32 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.google.protobuf;

/**
 * Thin bridge that exposes protobuf's {@code SchemaUtil.toCamelCase()} name conversion to Pinot code.
 *
 * <p>NOTE(review): this class is declared inside the protobuf library's own package, presumably because
 * {@code SchemaUtil} is not accessible from outside it -- confirm against the protobuf-java version in use.
 */
public class ProtobufInternalUtils {

  /** Static helper holder; never instantiated. */
  private ProtobufInternalUtils() {
  }

  /**
   * Converts a .proto identifier to the camel-case form used in generated Java accessor names.
   *
   * <p>The protocol buffer compiler generates a set of accessor methods for each field of a message, and the
   * method names are derived from the .proto names via {@code SchemaUtil.toCamelCase()}. Delegating to the same
   * routine lets callers build the exact method name to invoke.
   *
   * @param name the .proto identifier to convert
   * @param capNext passed through unchanged to {@code SchemaUtil.toCamelCase()}
   * @return the value returned by {@code SchemaUtil.toCamelCase(name, capNext)}
   */
  public static String underScoreToCamelCase(String name, boolean capNext) {
    final String camelCased = SchemaUtil.toCamelCase(name, capNext);
    return camelCased;
  }
}
@@ -0,0 +1,121 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.plugin.inputformat.protobuf;

import com.google.common.base.Preconditions;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import org.apache.pinot.plugin.inputformat.protobuf.codegen.MessageCodeGen;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.stream.StreamMessageDecoder;
import org.codehaus.janino.SimpleCompiler;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link StreamMessageDecoder} that decodes protobuf payloads through generated, runtime-compiled code.
 *
 * <p>During {@link #init} the decoder loads the user supplied jar, obtains the message descriptor from the
 * configured protobuf class, asks {@link MessageCodeGen} for the Java source of an extractor class, compiles that
 * source with Janino and looks up its static {@value #EXTRACTOR_METHOD_NAME} method. Every {@code decode} call then
 * invokes that method reflectively.
 */
public class ProtoBufCodeGenMessageDecoder implements StreamMessageDecoder<byte[]> {
  private static final Logger LOGGER = LoggerFactory.getLogger(ProtoBufCodeGenMessageDecoder.class);

  /** Decoder property holding the location of the jar that contains the compiled protobuf classes. */
  public static final String PROTOBUF_JAR_FILE_PATH = "jarFile";
  /** Decoder property holding the name of the protobuf message class to decode. */
  public static final String PROTO_CLASS_NAME = "protoClassName";
  public static final String EXTRACTOR_PACKAGE_NAME = "org.apache.pinot.plugin.inputformat.protobuf.decoder";
  public static final String EXTRACTOR_CLASS_NAME = "ProtobufRecorderMessageExtractor";
  public static final String EXTRACTOR_METHOD_NAME = "execute";

  // Replaced by the compiled extractor class in init(); the initial value looks like a placeholder only.
  private Class<?> _recordExtractor = ProtoBufMessageDecoder.class;
  // Static extractor method resolved in init(); null until init() has completed.
  private Method _decodeMethod;

  /**
   * Generates, compiles and binds the extractor for the configured protobuf message class.
   *
   * @param props decoder properties; must contain {@link #PROTOBUF_JAR_FILE_PATH} and {@link #PROTO_CLASS_NAME}
   * @param fieldsToRead fields handed to the code generator
   * @param topicName not used by this decoder
   * @throws IllegalStateException if a required property is missing
   * @throws Exception if the descriptor cannot be resolved or the extractor method cannot be found
   */
  @Override
  public void init(Map<String, String> props, Set<String> fieldsToRead, String topicName)
      throws Exception {
    Preconditions.checkState(props.containsKey(PROTOBUF_JAR_FILE_PATH),
        "Protocol Buffer schema jar file must be provided");
    Preconditions.checkState(props.containsKey(PROTO_CLASS_NAME),
        "Protocol Buffer Message class name must be provided");
    String protoClassName = props.getOrDefault(PROTO_CLASS_NAME, "");
    String jarPath = props.getOrDefault(PROTOBUF_JAR_FILE_PATH, "");
    ClassLoader protoMessageClsLoader = loadClass(jarPath);
    Descriptors.Descriptor descriptor = getDescriptorForProtoClass(protoMessageClsLoader, protoClassName);
    String codeGenCode = new MessageCodeGen().codegen(descriptor, fieldsToRead);
    _recordExtractor =
        compileClass(protoMessageClsLoader, EXTRACTOR_PACKAGE_NAME + "." + EXTRACTOR_CLASS_NAME, codeGenCode);
    _decodeMethod = _recordExtractor.getMethod(EXTRACTOR_METHOD_NAME, byte[].class, GenericRow.class);
  }

  /**
   * Decodes the whole payload by invoking the generated extractor method.
   *
   * @param payload serialized protobuf message
   * @param destination row passed to the extractor
   * @return whatever the generated extractor method returns
   * @throws RuntimeException wrapping any failure raised by the reflective call
   */
  @Nullable
  @Override
  public GenericRow decode(byte[] payload, GenericRow destination) {
    try {
      // The extractor method is static, hence the null receiver.
      return (GenericRow) _decodeMethod.invoke(null, payload, destination);
    } catch (Exception e) {
      throw new RuntimeException("Caught exception while decoding protobuf message", e);
    }
  }

  /**
   * Decodes {@code length} bytes of {@code payload} starting at {@code offset}.
   *
   * <p>The slice is copied into a new array because the generated extractor only accepts a full byte array.
   */
  @Nullable
  @Override
  public GenericRow decode(byte[] payload, int offset, int length, GenericRow destination) {
    return decode(Arrays.copyOfRange(payload, offset, offset + length), destination);
  }

  /**
   * Copies the jar at {@code jarFilePath} to local disk and returns a class loader that reads from it.
   *
   * @param jarFilePath location of the jar, resolved by {@link ProtoBufUtils#getFileCopiedToLocal(String)}
   * @return a new {@link URLClassLoader} whose only URL is the local copy of the jar
   * @throws RuntimeException if the jar cannot be copied or its URL cannot be built
   */
  public static ClassLoader loadClass(String jarFilePath) {
    try {
      File file = ProtoBufUtils.getFileCopiedToLocal(jarFilePath);
      URL url = file.toURI().toURL();
      return new URLClassLoader(new URL[]{url});
    } catch (Exception e) {
      throw new RuntimeException("Error loading protobuf class", e);
    }
  }

  /**
   * Compiles the generated source with Janino and loads the resulting class.
   *
   * @param classloader parent class loader for the compiler, so the generated code can see the protobuf classes
   * @param className fully qualified name of the class declared in {@code code}
   * @param code Java source to compile
   * @return the compiled class
   * @throws ClassNotFoundException if the compiled unit does not contain {@code className}
   * @throws RuntimeException if compilation fails
   */
  private Class<?> compileClass(ClassLoader classloader, String className, String code)
      throws ClassNotFoundException {
    SimpleCompiler simpleCompiler = new SimpleCompiler();
    simpleCompiler.setParentClassLoader(classloader);
    try {
      simpleCompiler.cook(code);
    } catch (Throwable t) {
      // Route the offending source through the logger instead of stdout so it lands in the service logs.
      LOGGER.error("Protobuf codegen compile error: \n{}", code);
      throw new RuntimeException("Program cannot be compiled. This is a bug. Please file an issue.", t);
    }
    return simpleCompiler.getClassLoader().loadClass(className);
  }

  /**
   * Resolves the protobuf descriptor of a generated message class via its static {@code getDescriptor()} method.
   *
   * @param protoMessageClsLoader class loader able to load {@code protoClassName}
   * @param protoClassName fully qualified name of the generated protobuf message class
   * @return the descriptor returned by the class's static {@code getDescriptor()} method
   * @throws ClassCastException if the loaded class is not a {@link Message}
   */
  public static Descriptors.Descriptor getDescriptorForProtoClass(ClassLoader protoMessageClsLoader,
      String protoClassName)
      throws NoSuchMethodException, ClassNotFoundException, InvocationTargetException, IllegalAccessException {
    // asSubclass performs a real runtime check, unlike an unchecked generic cast which is erased.
    Class<? extends Message> messageClass =
        protoMessageClsLoader.loadClass(protoClassName).asSubclass(Message.class);
    return (Descriptors.Descriptor) messageClass.getMethod("getDescriptor").invoke(null);
  }
}
Expand Up @@ -18,6 +18,8 @@
*/
package org.apache.pinot.plugin.inputformat.protobuf;

import com.google.protobuf.Descriptors;
import com.google.protobuf.ProtobufInternalUtils;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
Expand All @@ -29,17 +31,17 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class ProtoBufUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(ProtoBufUtils.class);
public static final String TMP_DIR_PREFIX = "pinot-protobuf";
public static final String PB_OUTER_CLASS_SUFFIX = "OuterClass";

private ProtoBufUtils() {
}

public static InputStream getDescriptorFileInputStream(String descriptorFilePath)
public static File getFileCopiedToLocal(String filePath)
throws Exception {
URI descriptorFileURI = URI.create(descriptorFilePath);
URI descriptorFileURI = URI.create(filePath);
String scheme = descriptorFileURI.getScheme();
if (scheme == null) {
scheme = PinotFSFactory.LOCAL_PINOT_FS_SCHEME;
Expand All @@ -48,21 +50,158 @@ public static InputStream getDescriptorFileInputStream(String descriptorFilePath
PinotFS pinotFS = PinotFSFactory.create(scheme);
Path localTmpDir = Files.createTempDirectory(TMP_DIR_PREFIX + System.currentTimeMillis());
File protoDescriptorLocalFile = createLocalFile(descriptorFileURI, localTmpDir.toFile());
LOGGER.info("Copying protocol buffer descriptor file from source: {} to dst: {}", descriptorFilePath,
LOGGER.info("Copying protocol buffer jar/descriptor file from source: {} to dst: {}", filePath,
protoDescriptorLocalFile.getAbsolutePath());
pinotFS.copyToLocalFile(descriptorFileURI, protoDescriptorLocalFile);
return new FileInputStream(protoDescriptorLocalFile);
return protoDescriptorLocalFile;
} else {
throw new RuntimeException(String.format("Scheme: %s not supported in PinotFSFactory"
+ " for protocol buffer descriptor file: %s.", scheme, descriptorFilePath));
+ " for protocol buffer jar/descriptor file: %s.", scheme, filePath));
}
}

/**
 * Opens an input stream on the descriptor file located at {@code descriptorFilePath}.
 *
 * <p>The file is first obtained through {@code getFileCopiedToLocal}; the stream reads from the file that call
 * returns. The caller owns the returned stream and is responsible for closing it.
 *
 * @param descriptorFilePath location of the descriptor file
 * @return a {@link FileInputStream} over the local file
 * @throws Exception if the file cannot be fetched or opened
 */
public static InputStream getDescriptorFileInputStream(String descriptorFilePath)
    throws Exception {
  File localDescriptorFile = getFileCopiedToLocal(descriptorFilePath);
  return new FileInputStream(localDescriptorFile);
}

/**
 * Builds the {@link File} handle inside {@code dstDir} that a remote file should be copied to.
 *
 * <p>The destination keeps the last path segment of {@code srcURI} as its file name. Only the handle is
 * constructed here; nothing is written to disk by this method.
 *
 * @param srcURI URI of the source file
 * @param dstDir local directory that will hold the copy
 * @return the destination file handle
 */
public static File createLocalFile(URI srcURI, File dstDir) {
  String baseName = new File(srcURI.getPath()).getName();
  File target = new File(dstDir, baseName);
  LOGGER.debug("Created empty local temporary file {} to copy protocol buffer descriptor {}",
      target.getAbsolutePath(), srcURI);
  return target;
}

// Copied from Flink codebase. https://github.com/apache/flink/blob/master/flink-formats/flink-protobuf/
// src/main/java/org/apache/flink/formats/protobuf/util/PbCodegenUtils.java
// This is needed since the generated class name is not always the same as the proto file name.
// The descriptor that we get from the jar drops the first prefix of the proto class name.
// For example, instead of com.data.example.ExampleProto we get data.example.ExampleProto.
/**
 * Returns the fully qualified Java name of the class generated for a protobuf message.
 *
 * @param descriptor descriptor of the message
 * @return for a nested message, the parent's Java name followed by "." and the message name; for a top level
 *     message, the outer prefix of its file followed by the message name
 */
public static String getFullJavaName(Descriptors.Descriptor descriptor) {
  Descriptors.Descriptor parent = descriptor.getContainingType();
  if (parent == null) {
    // Top level message: qualified by the package / outer class prefix of its .proto file.
    return getOuterProtoPrefix(descriptor.getFile()) + descriptor.getName();
  }
  // Nested message: qualified by the enclosing message's Java name.
  return getFullJavaName(parent) + "." + descriptor.getName();
}

/**
 * Returns the fully qualified Java name of the class generated for a protobuf enum.
 *
 * @param enumDescriptor descriptor of the enum
 * @return for an enum declared inside a message, the message's Java name followed by "." and the enum name; for
 *     a top level enum, the outer prefix of its file followed by the enum name
 */
public static String getFullJavaName(Descriptors.EnumDescriptor enumDescriptor) {
  Descriptors.Descriptor enclosingMessage = enumDescriptor.getContainingType();
  if (enclosingMessage == null) {
    // Top level enum.
    return getOuterProtoPrefix(enumDescriptor.getFile()) + enumDescriptor.getName();
  }
  // Enum declared inside a message.
  return getFullJavaName(enclosingMessage) + "." + enumDescriptor.getName();
}

/**
 * Returns the prefix (ending in ".") that qualifies the Java classes generated for top level types of a file.
 *
 * <p>The Java package is the {@code java_package} option when set, otherwise the proto package. With
 * {@code java_multiple_files} the prefix is just the package; otherwise the outer wrapper class name is appended.
 * When the file declares neither {@code java_package} nor a proto package, no package segment is emitted, so the
 * result never starts with a stray "." (which would not be a valid Java qualified name).
 *
 * @param fileDescriptor descriptor of the .proto file
 * @return the qualifying prefix; an empty string for a package-less file with {@code java_multiple_files}
 */
public static String getOuterProtoPrefix(Descriptors.FileDescriptor fileDescriptor) {
  String javaPackageName =
      fileDescriptor.getOptions().hasJavaPackage()
          ? fileDescriptor.getOptions().getJavaPackage()
          : fileDescriptor.getPackage();
  // Skip the separator for an empty package, otherwise the prefix would begin with ".".
  String packagePrefix = javaPackageName.isEmpty() ? "" : javaPackageName + ".";
  if (fileDescriptor.getOptions().getJavaMultipleFiles()) {
    return packagePrefix;
  }
  return packagePrefix + getOuterClassName(fileDescriptor) + ".";
}

/**
 * Returns the name of the outer (wrapper) Java class generated for a .proto file.
 *
 * <p>If {@code java_outer_classname} is set it is used as is. Otherwise the name is the camel-cased base name of
 * the .proto file, with {@link #PB_OUTER_CLASS_SUFFIX} appended when the file declares a service, enum or message
 * (including nested types) with that same name.
 *
 * @param fileDescriptor descriptor of the .proto file
 * @return the outer class name
 */
public static String getOuterClassName(Descriptors.FileDescriptor fileDescriptor) {
  if (fileDescriptor.getOptions().hasJavaOuterClassname()) {
    return fileDescriptor.getOptions().getJavaOuterClassname();
  }
  String[] fileNames = fileDescriptor.getName().split("/");
  String fileName = fileNames[fileNames.length - 1];
  String outerName = ProtobufInternalUtils.underScoreToCamelCase(fileName.split("\\.")[0], true);
  // https://developers.google.com/protocol-buffers/docs/reference/java-generated#invocation
  // The name of the wrapper class is determined by converting the base name of the .proto
  // file to camel case if the java_outer_classname option is not specified.
  // For example, foo_bar.proto produces the class name FooBar. If there is a service,
  // enum, or message (including nested types) in the file with the same name,
  // "OuterClass" will be appended to the wrapper class's name.
  boolean hasSameNameMessage =
      fileDescriptor.getMessageTypes().stream()
          .anyMatch(message -> messageDeclaresClassName(message, outerName));
  boolean hasSameNameEnum =
      fileDescriptor.getEnumTypes().stream()
          .anyMatch(enumType -> enumType.getName().equals(outerName));
  boolean hasSameNameService =
      fileDescriptor.getServices().stream()
          .anyMatch(service -> service.getName().equals(outerName));
  if (hasSameNameMessage || hasSameNameEnum || hasSameNameService) {
    return outerName + PB_OUTER_CLASS_SUFFIX;
  }
  return outerName;
}

/**
 * Tells whether {@code message}, or any message or enum nested (at any depth) inside it, is named
 * {@code className}.
 */
private static boolean messageDeclaresClassName(Descriptors.Descriptor message, String className) {
  if (message.getName().equals(className)) {
    return true;
  }
  for (Descriptors.EnumDescriptor nestedEnum : message.getEnumTypes()) {
    if (nestedEnum.getName().equals(className)) {
      return true;
    }
  }
  for (Descriptors.Descriptor nestedMessage : message.getNestedTypes()) {
    if (messageDeclaresClassName(nestedMessage, className)) {
      return true;
    }
  }
  return false;
}

/**
 * Maps a protobuf field to the Java type string used when emitting generated source.
 *
 * <p>Scalar fields map to the simple names of their boxed Java types ({@code ByteString} for bytes), enum and
 * message fields map to the fully qualified name of their generated class, and map fields map to
 * {@code Map<K,V>} built from the entry's "key" and "value" fields.
 *
 * @param fd field descriptor fetched directly from the protobuf object
 * @param isList when true, the element type is wrapped as {@code List<...>}
 * @return the code phrase to use as the Java type in codegen sections
 * @throws RuntimeException if the field's Java type is not supported
 */
public static String getTypeStrFromProto(Descriptors.FieldDescriptor fd, boolean isList) {
  final String elementType;
  switch (fd.getJavaType()) {
    case INT:
      elementType = "Integer";
      break;
    case LONG:
      elementType = "Long";
      break;
    case FLOAT:
      elementType = "Float";
      break;
    case DOUBLE:
      elementType = "Double";
      break;
    case BOOLEAN:
      elementType = "Boolean";
      break;
    case STRING:
      elementType = "String";
      break;
    case BYTE_STRING:
      elementType = "ByteString";
      break;
    case ENUM:
      elementType = getFullJavaName(fd.getEnumType());
      break;
    case MESSAGE: {
      if (!fd.isMapField()) {
        // Plain message field: use the generated class name.
        elementType = getFullJavaName(fd.getMessageType());
        break;
      }
      // Map field: the entry message exposes its two sides as "key" and "value"; neither can be repeated.
      Descriptors.FieldDescriptor keyField = fd.getMessageType().findFieldByName("key");
      Descriptors.FieldDescriptor valueField = fd.getMessageType().findFieldByName("value");
      String keyType = getTypeStrFromProto(keyField, false);
      String valueType = getTypeStrFromProto(valueField, false);
      elementType = "Map<" + keyType + "," + valueType + ">";
      break;
    }
    default:
      throw new RuntimeException("do not support field type: " + fd.getJavaType());
  }
  return isList ? "List<" + elementType + ">" : elementType;
}
}

0 comments on commit 8cadb65

Please sign in to comment.