Skip to content

Commit

Permalink
Merge pull request #583 from zinggAI/dbconnect
Browse files Browse the repository at this point in the history
Running using Databricks Connect #582
  • Loading branch information
sonalgoyal committed Jun 2, 2023
2 parents ede80ab + a3ddd46 commit 28dd2a9
Show file tree
Hide file tree
Showing 20 changed files with 764 additions and 224 deletions.
Expand Up @@ -92,6 +92,7 @@
@JsonInclude(Include.NON_NULL)
public class Arguments implements Serializable {

private static final long serialVersionUID = 1L;
// creates DriverArgs and invokes the main object
Pipe[] output;
Pipe[] data;
Expand Down
10 changes: 10 additions & 0 deletions common/client/src/main/java/zingg/common/client/Client.java
Expand Up @@ -15,6 +15,7 @@
*
*/
public abstract class Client<S,D,R,C,T> implements Serializable {
private static final long serialVersionUID = 1L;
protected Arguments arguments;
protected IZingg<S,D,R,C> zingg;
protected ClientOptions options;
Expand Down Expand Up @@ -283,4 +284,13 @@ public ZFrame<D,R,C> getUnmarkedRecords() {
return zingg.getUnmarkedRecords();
}

/**
 * Exposes the training data model of the wrapped {@code zingg} instance.
 *
 * @return the value returned by {@code zingg.getTrainingDataModel()}
 * @throws UnsupportedOperationException propagated from the delegate, which declares it
 */
public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException {
    return this.zingg.getTrainingDataModel();
}

/**
 * Exposes the label data view helper of the wrapped {@code zingg} instance.
 *
 * @return the value returned by {@code zingg.getLabelDataViewHelper()}
 * @throws UnsupportedOperationException propagated from the delegate, which declares it
 */
public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
    return this.zingg.getLabelDataViewHelper();
}


}
@@ -0,0 +1,28 @@
package zingg.common.client;

import java.util.List;

/**
 * Presentation-side operations used while labelling record pairs: picking the
 * pair to show, reading its score and prediction, building the messages shown
 * around it, and printing labelling statistics.
 *
 * <p>NOTE(review): the type parameters are not described anywhere in this file.
 * From the signatures, {@code R} is the element type of collected rows and
 * {@code C} the element type of display columns; {@code S} is unused by the
 * methods declared here — confirm against {@code ZFrame}.
 */
public interface ILabelDataViewHelper<S, D, R, C> {

/**
 * Derives the frame of cluster ids from {@code lines}.
 * In {@code LabelDataViewHelper} this selects the cluster column and keeps
 * the distinct values.
 */
ZFrame<D, R, C> getClusterIdsFrame(ZFrame<D, R, C> lines);

/**
 * Materialises rows of {@code lines} as a list.
 * In {@code LabelDataViewHelper} this is {@code lines.collectAsList()}, so the
 * caller is presumably expected to pass the frame returned by
 * {@link #getClusterIdsFrame} — TODO confirm.
 */
List<R> getClusterIds(ZFrame<D, R, C> lines);

/**
 * Returns the columns of {@code lines} that should be shown to the user,
 * as determined by {@code args}.
 */
List<C> getDisplayColumns(ZFrame<D, R, C> lines, Arguments args);

/**
 * Returns the rows of {@code lines} that belong to the cluster at position
 * {@code index} of {@code clusterIds}.
 * In {@code LabelDataViewHelper} the cluster id is read as a string through
 * {@code clusterLines} and the filtered result is cached.
 */
ZFrame<D, R, C> getCurrentPair(ZFrame<D, R, C> lines, int index, List<R> clusterIds, ZFrame<D, R, C> clusterLines);

/**
 * Returns the score of {@code currentPair}.
 * In {@code LabelDataViewHelper} it is read from the score column of the
 * first row.
 */
double getScore(ZFrame<D, R, C> currentPair);

/**
 * Returns the prediction of {@code currentPair}.
 * In {@code LabelDataViewHelper} it is read from the prediction column of the
 * first row.
 */
double getPrediction(ZFrame<D, R, C> currentPair);

/**
 * Builds the progress message from the current {@code index} and the
 * {@code totalPairs} count.
 */
String getMsg1(int index, int totalPairs);

/**
 * Builds the message describing the given {@code prediction} and
 * similarity {@code score}.
 */
String getMsg2(double prediction, double score);

/**
 * Shows {@code records} to the user, surrounded by {@code preMessage} and
 * {@code postMessage}.
 * In {@code LabelDataViewHelper} everything is written to standard output,
 * followed by the menu of choices (0, 1, 2 or 9).
 */
void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage);

/**
 * Reports how many pairs have been labelled in each category out of
 * {@code totalCount}.
 */
void printMarkedRecordsStat(long positivePairsCount, long negativePairsCount, long notSurePairsCount,
long totalCount);

}
@@ -0,0 +1,26 @@
package zingg.common.client;

import zingg.common.client.pipe.Pipe;

/**
 * Model-side operations of the labelling workflow: tracking how many pairs
 * fall into each label category, folding newly labelled records into the
 * accumulated set, and writing the labelled output.
 *
 * <p>Interface methods are implicitly {@code public}; the modifier is omitted.
 */
public interface ITrainingDataModel<S, D, R, C> {

    /** Initialises the per-category counters from already marked records. */
    void setMarkedRecordsStat(ZFrame<D, R, C> markedRecords);

    /**
     * Combines {@code newRecords}, labelled with {@code matchValue}, with the
     * records accumulated so far.
     *
     * @return the resulting accumulated frame
     */
    ZFrame<D, R, C> updateRecords(int matchValue, ZFrame<D, R, C> newRecords, ZFrame<D, R, C> updatedRecords);

    /** Adjusts the counter belonging to {@code selected_option} by {@code increment}. */
    void updateLabellerStat(int selected_option, int increment);

    /**
     * Writes the labelled {@code records} using the configuration in {@code args}.
     *
     * @throws ZinggClientException if the output cannot be written
     */
    void writeLabelledOutput(ZFrame<D, R, C> records, Arguments args) throws ZinggClientException;

    /**
     * Writes the labelled {@code records} to the given pipe {@code p}.
     *
     * @throws ZinggClientException if the output cannot be written
     */
    void writeLabelledOutput(ZFrame<D, R, C> records, Arguments args, Pipe p) throws ZinggClientException;

    /** @return the number of pairs counted as matches */
    long getPositivePairsCount();

    /** @return the number of pairs counted as non-matches */
    long getNegativePairsCount();

    /** @return the number of pairs counted as "not sure" */
    long getNotSurePairsCount();

    /** @return the total number of pairs counted */
    long getTotalCount();

}
6 changes: 5 additions & 1 deletion common/client/src/main/java/zingg/common/client/IZingg.java
Expand Up @@ -37,5 +37,9 @@ public void init(Arguments args, String license)
public ClientOptions getClientOptions();

public void setSession(S session);


/**
 * Returns the {@link ITrainingDataModel} exposed by this implementation.
 *
 * @throws UnsupportedOperationException presumably when the implementation
 *         has no training data model to offer — confirm against implementers
 */
public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException;

/**
 * Returns the {@link ILabelDataViewHelper} exposed by this implementation.
 *
 * @throws UnsupportedOperationException presumably when the implementation
 *         has no label data view helper to offer — confirm against implementers
 */
public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException;

}
@@ -0,0 +1,131 @@
package zingg.common.core.executor;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.Arguments;
import zingg.common.client.ClientOptions;
import zingg.common.client.ILabelDataViewHelper;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.Context;
import zingg.common.core.util.LabelMatchType;

/**
 * Console-oriented implementation of {@link ILabelDataViewHelper}: selects the
 * pair to label, reads its score and prediction, builds the user-facing
 * messages and prints them to standard output.
 *
 * <p>The class extends {@code ZinggBase} only to reuse its context and
 * utilities; it is not a runnable phase and {@link #execute()} always throws.
 */
public class LabelDataViewHelper<S,D,R,C,T> extends ZinggBase<S, D, R, C, T> implements ILabelDataViewHelper<S, D, R, C> {

    private static final long serialVersionUID = 1L;
    public static final Log LOG = LogFactory.getLog(LabelDataViewHelper.class);

    /**
     * Creates a helper bound to the given context and options. The name is set
     * to the runtime class name.
     */
    public LabelDataViewHelper(Context<S,D,R,C,T> context, ZinggOptions zinggOptions, ClientOptions clientOptions) {
        setContext(context);
        setZinggOptions(zinggOptions);
        setClientOptions(clientOptions);
        setName(this.getClass().getName());
    }

    /**
     * Returns the distinct values of the cluster column of {@code lines}.
     */
    @Override
    public ZFrame<D,R,C> getClusterIdsFrame(ZFrame<D,R,C> lines) {
        return lines.select(ColName.CLUSTER_COLUMN).distinct();
    }

    /**
     * Collects all rows of {@code lines} into a list.
     */
    @Override
    public List<R> getClusterIds(ZFrame<D,R,C> lines) {
        return lines.collectAsList();
    }

    /**
     * Returns the field-definition columns of {@code lines} to display,
     * honouring the "show concise" setting of {@code args}.
     */
    @Override
    public List<C> getDisplayColumns(ZFrame<D,R,C> lines, Arguments args) {
        return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
    }

    /**
     * Returns the rows of {@code lines} whose cluster column equals the cluster
     * id stored at {@code clusterIds.get(index)}. The id is read as a string
     * through {@code clusterLines}, and the filtered frame is cached.
     */
    @Override
    public ZFrame<D,R,C> getCurrentPair(ZFrame<D,R,C> lines, int index, List<R> clusterIds, ZFrame<D,R,C> clusterLines) {
        String clusterId = clusterLines.getAsString(clusterIds.get(index), ColName.CLUSTER_COLUMN);
        return lines.filter(lines.equalTo(ColName.CLUSTER_COLUMN, clusterId)).cache();
    }

    /**
     * Reads the score column of the first row of {@code currentPair}.
     */
    @Override
    public double getScore(ZFrame<D,R,C> currentPair) {
        return currentPair.getAsDouble(currentPair.head(), ColName.SCORE_COL);
    }

    /**
     * Reads the prediction column of the first row of {@code currentPair}.
     */
    @Override
    public double getPrediction(ZFrame<D,R,C> currentPair) {
        return currentPair.getAsDouble(currentPair.head(), ColName.PREDICTION_COL);
    }

    /**
     * Builds the progress line "index/totalPairs pairs labelled".
     */
    @Override
    public String getMsg1(int index, int totalPairs) {
        return String.format("\tCurrent labelling round : %d/%d pairs labelled\n", index, totalPairs);
    }

    /**
     * Builds the line describing Zingg's prediction for the displayed pair.
     *
     * <p>When {@code prediction} equals {@code ColValues.IS_NOT_KNOWN_PREDICTION}
     * a fixed explanatory message is returned. Otherwise the message names the
     * match type and the score truncated (not rounded) to two decimals.
     */
    @Override
    public String getMsg2(double prediction, double score) {
        if (prediction == ColValues.IS_NOT_KNOWN_PREDICTION) {
            // Constant text: no format arguments, and no match-type lookup needed on this branch.
            return "\tZingg does not do any prediction for the above pairs as Zingg is still collecting training data to build the preliminary models.";
        }
        // Resolve the label text only when it is actually used.
        String matchType = LabelMatchType.get(prediction).msg;
        return String.format("\tZingg predicts the above records %s with a similarity score of %.2f",
                matchType, Math.floor(score * 100) * 0.01);
    }

    /**
     * Prints {@code preMessage}, the records, {@code postMessage} and the menu
     * of choices to standard output. The final prompt is printed without a
     * trailing newline so the user's answer appears on the same line.
     */
    @Override
    public void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage) {
        System.out.println(preMessage);
        records.show(false);
        System.out.println(postMessage);
        System.out.println("\tWhat do you think? Your choices are: ");
        System.out.println();

        System.out.println("\tNo, they do not match : 0");
        System.out.println("\tYes, they match : 1");
        System.out.println("\tNot sure : 2");
        System.out.println();
        System.out.println("\tTo exit : 9");
        System.out.println();
        System.out.print("\tPlease enter your choice [0,1,2 or 9]: ");
    }

    /**
     * Prints three blank lines followed by a summary of how many pairs have
     * been labelled in each category out of {@code totalCount}.
     */
    @Override
    public void printMarkedRecordsStat(long positivePairsCount,long negativePairsCount,long notSurePairsCount,long totalCount) {
        String msg = String.format(
                "\tLabelled pairs so far : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount,
                negativePairsCount, totalCount, notSurePairsCount, totalCount);

        System.out.println();
        System.out.println();
        System.out.println();
        System.out.println(msg);
    }

    /**
     * Not supported: this helper is not an executable phase.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void execute() throws ZinggClientException {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns this instance, which is itself the view helper.
     */
    @Override
    public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
        return this;
    }

}
Expand Up @@ -14,7 +14,8 @@
import zingg.common.core.util.LabelMatchType;

public abstract class LabelUpdater<S,D,R,C,T> extends Labeller<S,D,R,C,T> {
protected static String name = "zingg.LabelUpdater";
private static final long serialVersionUID = 1L;
protected static String name = "zingg.common.core.executor.LabelUpdater";
public static final Log LOG = LogFactory.getLog(LabelUpdater.class);

public LabelUpdater() {
Expand All @@ -33,12 +34,18 @@ public void execute() throws ZinggClientException {
}
}

public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
LOG.info("Processing Records for CLI updateLabelling");

if (lines != null && lines.count() > 0) {
getMarkedRecordsStat(lines);
printMarkedRecordsStat();
getTrainingDataModel().setMarkedRecordsStat(lines);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);


List<C> displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
try {
Expand All @@ -52,7 +59,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
do {
System.out.print("\n\tPlease enter the cluster id (or 9 to exit): ");
String cluster_id = sc.next();
if (cluster_id.equals("9")) {
if (cluster_id.equals(QUIT_LABELING.toString())) {
LOG.info("User has exit in the middle. Updating the records.");
break;
}
Expand All @@ -67,10 +74,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
String matchType = LabelMatchType.get(matchFlag).msg;
postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg);
updateLabellerStat(selectedOption, +1);
updateLabellerStat(matchFlag, -1);
printMarkedRecordsStat();
if (selectedOption == 9) {
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);

if (selectedOption == QUIT_LABELING) {
LOG.info("User has quit in the middle. Updating the records.");
break;
}
Expand All @@ -80,15 +93,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
updatedRecords = updatedRecords
.filter(updatedRecords.notEqual(ColName.CLUSTER_COLUMN,cluster_id));
}
updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != 9);
updatedRecords = getTrainingDataModel().updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != QUIT_LABELING);

if (updatedRecords != null) {
updatedRecords = updatedRecords.union(recordsToUpdate);
}
writeLabelledOutput(updatedRecords);
getTrainingDataModel().writeLabelledOutput(updatedRecords,args,getOutputPipe());
sc.close();
LOG.info("Processing finished.");
return updatedRecords;
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
e.printStackTrace();
Expand All @@ -98,6 +112,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
}
} else {
LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.");
return null;
}
}

Expand Down

0 comments on commit 28dd2a9

Please sign in to comment.