Skip to content

Commit

Permalink
Revert "Merge pull request #774 from gnanaprakash-ravi/ZEIssue230"
Browse files Browse the repository at this point in the history
This reverts commit 7d9567c, reversing
changes made to f9e9528.
  • Loading branch information
sonalgoyal committed Mar 7, 2024
1 parent 330a073 commit 6d87021
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 1,044 deletions.
22 changes: 6 additions & 16 deletions common/client/src/main/java/zingg/common/client/Arguments.java
Expand Up @@ -17,8 +17,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;

import zingg.common.client.pipe.Pipe;
import zingg.common.py.annotations.PythonClass;
import zingg.common.py.annotations.PythonMethod;


/**
Expand Down Expand Up @@ -81,7 +79,6 @@
* }
* </pre>
*/
@PythonClass
@JsonInclude(Include.NON_NULL)
public class Arguments implements Serializable, IArguments {

Expand Down Expand Up @@ -124,7 +121,7 @@ public Arguments() {
public int getNumPartitions() {
return numPartitions;
}
@PythonMethod

@Override
public void setNumPartitions(int numPartitions) throws ZinggClientException{
if (numPartitions != -1 && numPartitions <= 0)
Expand Down Expand Up @@ -157,7 +154,6 @@ public float getLabelDataSampleSize() {
* generating seed samples
* @throws ZinggClientException
*/
@PythonMethod
@Override
public void setLabelDataSampleSize(float labelDataSampleSize) throws ZinggClientException {
if (labelDataSampleSize > 1 || labelDataSampleSize < 0)
Expand Down Expand Up @@ -239,12 +235,12 @@ public void setZinggInternal(Pipe[] zinggDir) {
*/


@PythonMethod

@Override
public String getModelId() {
return modelId;
}
@PythonMethod

@Override
public void setModelId(String modelId) {
this.modelId = modelId;
Expand All @@ -267,7 +263,6 @@ public Pipe[] getOutput() {
* where the match result is saved
* @throws ZinggClientException
*/
@PythonMethod
@Override
public void setOutput(Pipe[] outputDir) throws ZinggClientException {
//checkNullBlankEmpty(outputDir, " path for saving results");
Expand Down Expand Up @@ -345,7 +340,6 @@ public String getZinggDir() {
* @param zinggDir
* path to the Zingg directory
*/
@PythonMethod
@Override
public void setZinggDir(String zinggDir) {
this.zinggDir = zinggDir;
Expand All @@ -357,13 +351,12 @@ public void setZinggDir(String zinggDir) {
*
* @return the path for internal Zingg usage
*/
@PythonMethod

@Override
@JsonIgnore
public String getZinggBaseModelDir(){
return zinggDir + "/" + modelId;
}
@PythonMethod
@Override
@JsonIgnore
public String getZinggModelDir() {
Expand Down Expand Up @@ -393,7 +386,6 @@ public String getZinggDataDocFile() {
*
* @return the path for internal Zingg usage
*/
@PythonMethod
@Override
@JsonIgnore
public String getZinggBaseTrainingDataDir() {
Expand All @@ -407,7 +399,6 @@ public String getZinggBaseTrainingDataDir() {
*
* @return the path for internal Zingg usage
*/
@PythonMethod
@Override
@JsonIgnore
public String getZinggTrainingDataUnmarkedDir() {
Expand All @@ -419,7 +410,6 @@ public String getZinggTrainingDataUnmarkedDir() {
*
* @return the path for internal Zingg usage
*/
@PythonMethod
@Override
@JsonIgnore
public String getZinggTrainingDataMarkedDir() {
Expand Down Expand Up @@ -488,7 +478,7 @@ public void setCollectMetrics(boolean collectMetrics) {
public float getStopWordsCutoff() {
return stopWordsCutoff;
}
@PythonMethod

@Override
public void setStopWordsCutoff(float stopWordsCutoff) throws ZinggClientException {
if (stopWordsCutoff > 1 || stopWordsCutoff < 0)
Expand All @@ -510,7 +500,7 @@ public void setShowConcise(boolean showConcise) {
public String getColumn() {
return column;
}
@PythonMethod

@Override
public void setColumn(String column) {
this.column = column;
Expand Down
Expand Up @@ -22,10 +22,6 @@
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;

import zingg.common.client.cols.Named;
import zingg.common.py.annotations.PythonClass;
import zingg.common.py.annotations.PythonMethod;


/**
* This class defines each field that we use in matching We can use this to
Expand Down Expand Up @@ -56,7 +52,7 @@ public FieldDefinition() {
}

public String getFields() { return fields; }
@PythonMethod

public void setFields(String fields) { this.fields = fields;}

/**
Expand All @@ -75,7 +71,6 @@ public List<MatchType> getMatchType() {
* @param type
* the type to set
*/
@PythonMethod
@JsonDeserialize(using = MatchTypeDeserializer.class)
public void setMatchType(List<MatchType> type) {
this.matchType = type; //MatchTypeDeserializer.getMatchTypeFromString(type);
Expand Down Expand Up @@ -103,7 +98,7 @@ public void setDataType(String d) {
public String getStopWords() {
return stopWords;
}
@PythonMethod

public void setStopWords(String stopWords) {
this.stopWords = stopWords;
}
Expand All @@ -120,7 +115,6 @@ public String getFieldName() {
return fieldName;
}

@PythonMethod
public void setFieldName(String fieldName) {
this.fieldName = fieldName;
}
Expand Down
@@ -1,8 +1,5 @@
package zingg.common.py.processors;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -20,6 +17,7 @@
@SupportedAnnotationTypes("zingg.common.py.annotations.PythonClass")
public class PythonClassProcessor extends AbstractProcessor {

private boolean importsAndDeclarationsGenerated = false;
private Map<String, List<String>> classMethodsMap = new HashMap<>();

@Override
Expand All @@ -30,6 +28,12 @@ public synchronized void init(ProcessingEnvironment processingEnv) {
@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {

// Imports and global declarations
if (!importsAndDeclarationsGenerated) {
generateImportsAndDeclarations();
importsAndDeclarationsGenerated = true;
}


// process Services annotation
for (Element element : roundEnv.getElementsAnnotatedWith(PythonClass.class)) {
Expand All @@ -44,22 +48,28 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment
try (FileWriter fileWriter = new FileWriter(outputDirectory + File.separator + element.getSimpleName() + "Generated.py")) {
generateImportsAndDeclarations(element, fileWriter);

fileWriter.write("class " + element.getSimpleName() + ":\n");
System.out.println("class " + element.getSimpleName() + ":");

// __init__ method
System.out.println(" def __init__(self" +
generateConstructorParameters(classElement) + "):");
generateClassInitializationCode(classElement, element);

// __init__ method
fileWriter.write(" def __init__(self" + generateConstructorParameters(classElement, element) + "):\n");
generateClassInitializationCode(classElement, element, fileWriter);
// for (VariableElement field : ElementFilter.fieldsIn(classElement.getEnclosedElements())) {
// if (!field.getSimpleName().contentEquals("serialVersionUID")) {
// generateFieldInitializationCode(field, element);
// }
// }

for (ExecutableElement methodElement : ElementFilter.methodsIn(classElement.getEnclosedElements())) {
if (methodElement.getAnnotation(PythonMethod.class) != null) {
methodNames.add(methodElement.getSimpleName().toString());
}
for (ExecutableElement methodElement : ElementFilter.methodsIn(classElement.getEnclosedElements())) {
if (methodElement.getAnnotation(PythonMethod.class) != null) {
methodNames.add(methodElement.getSimpleName().toString());
}
classMethodsMap.put(element.getSimpleName().toString(), methodNames);
} catch (IOException e) {
e.printStackTrace();
}
classMethodsMap.put(element.getSimpleName().toString(), methodNames);
}
System.out.println();
// rest of generated class contents
}
ProcessorContext processorContext = ProcessorContext.getInstance();
processorContext.getClassMethodsMap().putAll(classMethodsMap);
Expand Down Expand Up @@ -94,13 +104,20 @@ private void generateImportsAndDeclarations(Element element, FileWriter fileWrit
fileWriter.write("JStructType = getJVM().org.apache.spark.sql.types.StructType\n");
fileWriter.write("\n");
}
private void generateImportsAndDeclarations() {
System.out.println("import logging");
System.out.println("from zingg.client import *");
System.out.println("LOG = logging.getLogger(\"zingg.pipes\")");
System.out.println();
System.out.println("JPipe = getJVM().zingg.spark.client.pipe.SparkPipe");
System.out.println("FilePipe = getJVM().zingg.common.client.pipe.FilePipe");
System.out.println("JStructType = getJVM().org.apache.spark.sql.types.StructType");
System.out.println();
}

private void generateClassInitializationCode(TypeElement classElement, Element element, FileWriter fileWriter) throws IOException {
private void generateClassInitializationCode(TypeElement classElement, Element element) {
if (element.getSimpleName().contentEquals("Pipe")) {
fileWriter.write(" self." + element.getSimpleName().toString().toLowerCase() + " = getJVM().zingg.spark.client.pipe.SparkPipe()\n");
fileWriter.write(" self." + element.getSimpleName().toString().toLowerCase() + ".setName(name)\n");
fileWriter.write(" self." + element.getSimpleName().toString().toLowerCase() + ".setFormat(format)\n");
System.out.println(" self." + element.getSimpleName().toString().toLowerCase() + " = getJVM().zingg.spark.client.pipe.SparkPipe()");
}
else if (element.getSimpleName().contentEquals("EPipe")) {
fileWriter.write(" self." + element.getSimpleName().toString().toLowerCase() + " = getJVM().zingg.spark.client.pipe.SparkPipe()\n");
Expand Down Expand Up @@ -134,32 +151,18 @@ else if (element.getSimpleName().contentEquals("FieldDefinition")) {
// }
// }

private String generateConstructorParameters(TypeElement classElement, Element element) {

private String generateConstructorParameters(TypeElement classElement) {
StringBuilder parameters = new StringBuilder();
List<VariableElement> fields = ElementFilter.fieldsIn(classElement.getEnclosedElements());

if (element.getSimpleName().contentEquals("Arguments")) {
// For the "Arguments" class, no constructor parameters are needed
return "";
}
else if (element.getSimpleName().contentEquals("Pipe")) {
parameters.append(", name, format");
}
else if (element.getSimpleName().contentEquals("FieldDefinition")) {
parameters.append(", name, dataType, *matchType");
}
else {
List<VariableElement> fields = ElementFilter.fieldsIn(classElement.getEnclosedElements());

fields = fields.stream()
.filter(field -> !field.getSimpleName().contentEquals("serialVersionUID"))
.filter(this::isFieldForConstructor)
.collect(Collectors.toList());

for (VariableElement field : fields) {
parameters.append(", ");
parameters.append(field.getSimpleName());
}
fields = fields.stream()
.filter(field -> !field.getSimpleName().contentEquals("serialVersionUID"))
.filter(this::isFieldForConstructor)
.collect(Collectors.toList());

for (VariableElement field : fields) {
parameters.append(", ");
parameters.append(field.getSimpleName());
}
return parameters.toString();
}
Expand Down

0 comments on commit 6d87021

Please sign in to comment.