Skip to content

Commit

Permalink
Merge pull request #801 from zinggAI/selectCols
Browse files Browse the repository at this point in the history
using ZidAndFieldDefSelector to select cols in ftd, label, matcher
  • Loading branch information
sonalgoyal committed Mar 5, 2024
2 parents 942791a + fcb9e2b commit 5aab06a
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 42 deletions.
Expand Up @@ -8,7 +8,7 @@ public interface ILabelDataViewHelper<S, D, R, C> {

List<R> getClusterIds(ZFrame<D, R, C> lines);

List<C> getDisplayColumns(ZFrame<D, R, C> lines, IArguments args);
// List<C> getDisplayColumns(ZFrame<D, R, C> lines, IArguments args);

ZFrame<D, R, C> getCurrentPair(ZFrame<D, R, C> lines, int index, List<R> clusterIds, ZFrame<D, R, C> clusterLines);

Expand Down
Expand Up @@ -4,27 +4,38 @@
import java.util.List;

import zingg.common.client.FieldDefinition;
import zingg.common.client.MatchType;

public class FieldDefSelectedCols extends SelectedCols {

public FieldDefSelectedCols(List<FieldDefinition> fieldDefs, boolean showConcise) {
protected FieldDefSelectedCols() {

}

public FieldDefSelectedCols(List<? extends FieldDefinition> fieldDefs, boolean showConcise) {
List<String> colList = getColList(fieldDefs, showConcise);
setCols(colList);
}

protected List<String> getColList(List<? extends FieldDefinition> fieldDefs) {
return getColList(fieldDefs,false);
}

List<FieldDefinition> namedList = new ArrayList<>();
protected List<String> getColList(List<? extends FieldDefinition> fieldDefs, boolean showConcise) {
List<FieldDefinition> namedList = new ArrayList<FieldDefinition>();

for (FieldDefinition fieldDef : fieldDefs) {
if (showConcise && fieldDef.isDontUse()) {
if (showConcise && fieldDef.matchType.contains(MatchType.DONT_USE)) {
continue;
}
namedList.add(fieldDef);
}

namedList.add(new FieldDefinition());
List<String> stringList = convertNamedListToStringList(namedList);
setCols(stringList);
}
return stringList;
}

private List<String> convertNamedListToStringList(List<FieldDefinition> namedList) {
List<String> stringList = new ArrayList<>();
protected List<String> convertNamedListToStringList(List<? extends FieldDefinition> namedList) {
List<String> stringList = new ArrayList<String>();
for (FieldDefinition named : namedList) {
stringList.add(named.getName());
}
Expand Down
@@ -1,16 +1,24 @@
package zingg.common.client.cols;

import java.util.Arrays;
import java.util.List;

import zingg.common.client.FieldDefinition;
import zingg.common.client.util.ColName;

public class ZidAndFieldDefSelector extends SelectedCols {
public class ZidAndFieldDefSelector extends FieldDefSelectedCols {

public ZidAndFieldDefSelector(String[] fieldDefs) {
public ZidAndFieldDefSelector(List<? extends FieldDefinition> fieldDefs) {
this(fieldDefs, true, false);
}

public ZidAndFieldDefSelector(List<? extends FieldDefinition> fieldDefs, boolean includeZid, boolean showConcise) {
List<String> colList = getColList(fieldDefs, showConcise);

if (includeZid) colList.add(0, ColName.ID_COL);

colList.add(ColName.SOURCE_COL);

setCols(colList);
}

List<String> fieldDefList = Arrays.asList(fieldDefs);
fieldDefList.add(0, ColName.ID_COL);
setCols(fieldDefList);
}
}
Expand Up @@ -6,11 +6,9 @@
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ClientOptions;
import zingg.common.client.IArguments;
import zingg.common.client.ILabelDataViewHelper;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.context.Context;
Expand Down Expand Up @@ -39,11 +37,11 @@ public List<R> getClusterIds(ZFrame<D,R,C> lines) {
}


@Override
public List<C> getDisplayColumns(ZFrame<D,R,C> lines, IArguments args) {
return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
}

// @Override
// public List<C> getDisplayColumns(ZFrame<D,R,C> lines, IArguments args) {
// return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
// }
//

@Override
public ZFrame<D,R,C> getCurrentPair(ZFrame<D,R,C> lines, int index, List<R> clusterIds, ZFrame<D,R,C> clusterLines) {
Expand Down
@@ -1,13 +1,13 @@
package zingg.common.core.executor;

import java.util.List;
import java.util.Scanner;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.cols.ZidAndFieldDefSelector;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.pipe.Pipe;
import zingg.common.client.util.ColName;
Expand Down Expand Up @@ -125,14 +125,14 @@ protected ZFrame<D, R, C> getUpdatedRecords(ZFrame<D, R, C> updatedRecords, int
}

protected int getUserInput(ZFrame<D,R,C> lines,ZFrame<D,R,C> currentPair,String cluster_id) {

List<C> displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());

// List<C> displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise());
int matchFlag = currentPair.getAsInt(currentPair.head(),ColName.MATCH_FLAG_COL);
String preMsg = String.format("\n\tThe record pairs belonging to the input cluster id %s are:", cluster_id);
String matchType = LabelMatchType.get(matchFlag).msg;
String postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg);
// int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg);
int selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), preMsg, postMsg);
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
Expand Down
Expand Up @@ -10,6 +10,7 @@
import zingg.common.client.ITrainingDataModel;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.cols.ZidAndFieldDefSelector;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;

Expand Down Expand Up @@ -79,7 +80,8 @@ public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientE
);

lines = lines.cache();
List<C> displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args);
// List<C> displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args);
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise());
//have to introduce as snowframe can not handle row.getAs with column
//name and row and lines are out of order for the code to work properly
//snow getAsString expects row to have same struc as dataframe which is
Expand All @@ -104,7 +106,8 @@ public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientE
msg2 = getLabelDataViewHelper().getMsg2(prediction, score);
//String msgHeader = msg1 + msg2;

selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2);
// selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2);
selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), msg1, msg2);
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
Expand Down
Expand Up @@ -8,6 +8,7 @@

import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.cols.ZidAndFieldDefSelector;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
Expand Down Expand Up @@ -35,7 +36,9 @@ public ZFrame<D,R,C> getTestData() throws ZinggClientException{
}

public ZFrame<D, R, C> getFieldDefColumnsDS(ZFrame<D, R, C> testDataOriginal) {
return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true);
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition());
return testDataOriginal.select(zidAndFieldDefSelector.getCols());
// return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true);
}


Expand All @@ -46,13 +49,7 @@ public ZFrame<D,R,C> getBlocked( ZFrame<D,R,C> testData) throws Exception, Zin
ZFrame<D,R,C> blocked1 = blocked.repartition(args.getNumPartitions(), blocked.col(ColName.HASH_COL)); //.cache();
return blocked1;
}



public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C>blocked) throws Exception{
return getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();
}

public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
ZFrame<D,R,C>joinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();
/*ZFrame<D,R,C>joinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL)
Expand Down
@@ -1,10 +1,13 @@
package zingg.common.core.executor;

import java.util.Arrays;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.cols.ZidAndFieldDefSelector;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.pipe.Pipe;
import zingg.common.client.util.ColName;
Expand Down Expand Up @@ -79,7 +82,7 @@ public void execute() throws ZinggClientException {
if (negPairs!= null) negPairs = negPairs.cache();
//create random samples for blocking
ZFrame<D,R,C> sampleOrginal = data.sample(false, args.getLabelDataSampleSize()).repartition(args.getNumPartitions()).cache();
sampleOrginal = getDSUtil().getFieldDefColumnsDS(sampleOrginal, args, true);
sampleOrginal = getFieldDefColumnsDS(sampleOrginal);
LOG.info("Preprocessing DS for stopWords");

ZFrame<D,R,C> sample = getStopWords().preprocessForStopWords(sampleOrginal);
Expand Down Expand Up @@ -188,7 +191,7 @@ public ZFrame<D,R,C> getPositiveSamples(ZFrame<D,R,C> data) throws Exception {
}
ZFrame<D,R,C> posSample = data.sample(false, args.getLabelDataSampleSize());
//select only those columns which are mentioned in the field definitions
posSample = getDSUtil().getFieldDefColumnsDS(posSample, args, true);
posSample = getFieldDefColumnsDS(posSample);
if (LOG.isDebugEnabled()) {
LOG.debug("Sampled " + posSample.count());
}
Expand All @@ -202,8 +205,13 @@ public ZFrame<D,R,C> getPositiveSamples(ZFrame<D,R,C> data) throws Exception {
return posPairs;
}

protected ZFrame<D, R, C> getFieldDefColumnsDS(ZFrame<D, R, C> data) {
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition());
String[] cols = zidAndFieldDefSelector.getCols();
return data.select(cols);
//return getDSUtil().getFieldDefColumnsDS(data, args, true);
}

protected abstract StopWordsRemover<S,D,R,C,T> getStopWords();



}

0 comments on commit 5aab06a

Please sign in to comment.