Skip to content

Commit

Permalink
Merge pull request #804 from zinggAI/selectCols
Browse files Browse the repository at this point in the history
overloaded getPairs getActualDupes method and event listener bugs removed
  • Loading branch information
sonalgoyal committed Mar 18, 2024
2 parents ee6027e + d2730de commit f96145e
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 71 deletions.
5 changes: 2 additions & 3 deletions common/client/src/main/java/zingg/common/client/Client.java
Expand Up @@ -214,8 +214,6 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) {
client = getClient(arguments, options);
client.init();
// after setting arguments etc. as some of the listeners need it
initializeListeners();
EventsListener.getInstance().fireEvent(new ZinggStartEvent());
client.execute();
client.postMetrics();
LOG.warn("Zingg processing has completed");
Expand Down Expand Up @@ -263,7 +261,8 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) {
public void init() throws ZinggClientException {
zingg.init(getArguments(), getSession());
if (session != null) zingg.setSession(session);

initializeListeners();
EventsListener.getInstance().fireEvent(new ZinggStartEvent());
}

/**
Expand Down
@@ -0,0 +1,23 @@
package zingg.common.client.cols;

import java.util.ArrayList;
import java.util.List;

import zingg.common.client.util.ColName;

public class PredictionColsSelector extends SelectedCols {

public PredictionColsSelector() {

List<String> cols = new ArrayList<String>();
cols.add(ColName.ID_COL);
cols.add(ColName.COL_PREFIX + ColName.ID_COL);
cols.add(ColName.PREDICTION_COL);
cols.add(ColName.SCORE_COL);

setCols(cols);

}


}
@@ -1,25 +1,25 @@
package zingg.common.client.event.listeners;

import java.util.List;

import zingg.common.client.ZinggClientException;
import zingg.common.client.event.events.IEvent;
import zingg.common.client.util.ListMap;

public class EventsListener {
private static EventsListener eventsListener = null;
private final ListMap<String, IEventListener> eventListeners;
private static EventsListener _eventsListener = new EventsListener();
private final ListMap<String, IEventListener> eventListenersList;

private EventsListener() {
eventListeners = new ListMap<>();
eventListenersList = new ListMap<String, IEventListener>();
}

public static EventsListener getInstance() {
if (eventsListener == null)
eventsListener = new EventsListener();
return eventsListener;
return _eventsListener;
}

public void addListener(Class<? extends IEvent> eventClass, IEventListener listener) {
eventListeners.add(eventClass.getCanonicalName(), listener);
eventListenersList.add(eventClass.getCanonicalName(), listener);
}

public void fireEvent(IEvent event) throws ZinggClientException {
Expand All @@ -28,8 +28,13 @@ public void fireEvent(IEvent event) throws ZinggClientException {

private void listen(IEvent event) throws ZinggClientException {
Class<? extends IEvent> eventClass = event.getClass();
for (IEventListener listener : eventListeners.get(eventClass.getCanonicalName())) {
listener.listen(event);
}
List<IEventListener> listenerList = eventListenersList.get(eventClass.getCanonicalName());
if (listenerList != null) {
for (IEventListener listener : listenerList) {
if (listener != null) {
listener.listen(event);
}
}
}
}
}
33 changes: 15 additions & 18 deletions common/core/src/main/java/zingg/common/core/executor/Linker.java
Expand Up @@ -7,8 +7,7 @@
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.filter.PredictionFilter;
import zingg.common.core.pairs.SelfPairBuilderSourceSensitive;


Expand All @@ -27,13 +26,26 @@ public Linker() {
public ZFrame<D,R,C> selectColsFromBlocked(ZFrame<D,R,C> blocked) {
return blocked;
}

@Override
public ZFrame<D,R,C> getPairs(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
return getPairs(blocked, bAll, new SelfPairBuilderSourceSensitive<S, D, R, C> (getDSUtil(),args));
}

@Override
protected ZFrame<D,R,C> getActualDupes(ZFrame<D,R,C> blocked, ZFrame<D,R,C> testData) throws Exception, ZinggClientException{
PredictionFilter<D, R, C> predictionFilter = new PredictionFilter<D, R, C>();
SelfPairBuilderSourceSensitive<S, D, R, C> iPairBuilder = new SelfPairBuilderSourceSensitive<S, D, R, C> (getDSUtil(),args);
return getActualDupes(blocked, testData,predictionFilter, iPairBuilder, null);
}

@Override
public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws ZinggClientException {
try {
// input dupes are pairs
/// pick ones according to the threshold by user
ZFrame<D,R,C> dupesActual = getDupesActualForGraph(dupes);
PredictionFilter<D, R, C> predictionFilter = new PredictionFilter<D, R, C>();
ZFrame<D,R,C> dupesActual = predictionFilter.filter(dupes);

// all clusters consolidated in one place
if (args.getOutput() != null) {
Expand All @@ -52,19 +64,4 @@ public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws
}
}

@Override
public ZFrame<D,R,C> getDupesActualForGraph(ZFrame<D,R,C> dupes) {
ZFrame<D,R,C> dupesActual = dupes
.filter(dupes.equalTo(ColName.PREDICTION_COL, ColValues.IS_MATCH_PREDICTION));
return dupesActual;
}

@Override
public IPairBuilder<S, D, R, C> getIPairBuilder() {
if(iPairBuilder==null) {
iPairBuilder = new SelfPairBuilderSourceSensitive<S, D, R, C> (getDSUtil(),args);
}
return iPairBuilder;
}

}
62 changes: 22 additions & 40 deletions common/core/src/main/java/zingg/common/core/executor/Matcher.java
Expand Up @@ -8,12 +8,14 @@

import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.cols.PredictionColsSelector;
import zingg.common.client.cols.ZidAndFieldDefSelector;
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;
import zingg.common.core.filter.IFilter;
import zingg.common.core.filter.PredictionFilter;
import zingg.common.core.model.Model;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.pairs.SelfPairBuilder;
Expand All @@ -27,8 +29,6 @@ public abstract class Matcher<S,D,R,C,T> extends ZinggBase<S,D,R,C,T>{
protected static String name = "zingg.Matcher";
public static final Log LOG = LogFactory.getLog(Matcher.class);

protected IPairBuilder<S, D, R, C> iPairBuilder;

public Matcher() {
setZinggOption(ZinggOptions.MATCH);
}
Expand All @@ -54,7 +54,11 @@ public ZFrame<D,R,C> getBlocked( ZFrame<D,R,C> testData) throws Exception, Zin
}

public ZFrame<D,R,C> getPairs(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
return getIPairBuilder().getPairs(blocked, bAll);
return getPairs(blocked, bAll, new SelfPairBuilder<S, D, R, C> (getDSUtil(),args));
}

public ZFrame<D,R,C> getPairs(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll, IPairBuilder<S, D, R, C> iPairBuilder) throws Exception{
return iPairBuilder.getPairs(blocked, bAll);
}

protected abstract Model getModel() throws ZinggClientException;
Expand All @@ -76,11 +80,22 @@ protected ZFrame<D,R,C> predictOnBlocks(ZFrame<D,R,C>blocks) throws Exception, Z
}

protected ZFrame<D,R,C> getActualDupes(ZFrame<D,R,C> blocked, ZFrame<D,R,C> testData) throws Exception, ZinggClientException{
ZFrame<D,R,C> blocks = getPairs(selectColsFromBlocked(blocked), testData);
ZFrame<D,R,C>dupesActual = predictOnBlocks(blocks);
return getDupesActualForGraph(dupesActual);
PredictionFilter<D, R, C> predictionFilter = new PredictionFilter<D, R, C>();
SelfPairBuilder<S, D, R, C> iPairBuilder = new SelfPairBuilder<S, D, R, C> (getDSUtil(),args);
return getActualDupes(blocked, testData,predictionFilter, iPairBuilder,new PredictionColsSelector());
}

protected ZFrame<D,R,C> getActualDupes(ZFrame<D,R,C> blocked, ZFrame<D,R,C> testData,
IFilter<D, R, C> predictionFilter, IPairBuilder<S, D, R, C> iPairBuilder, PredictionColsSelector colsSelector) throws Exception, ZinggClientException{
ZFrame<D,R,C> blocks = getPairs(selectColsFromBlocked(blocked), testData, iPairBuilder);
ZFrame<D,R,C>dupesActual = predictOnBlocks(blocks);
ZFrame<D, R, C> filteredData = predictionFilter.filter(dupesActual);
if(colsSelector!=null) {
filteredData = filteredData.select(colsSelector.getCols());
}
return filteredData;
}

@Override
public void execute() throws ZinggClientException {
try {
Expand Down Expand Up @@ -251,40 +266,7 @@ protected ZFrame<D, R, C> getGraphWithScores(ZFrame<D, R, C> graph, ZFrame<D, R,
return allScores.groupByMinMaxScore(allScores.col(ColName.ID_COL));
}

protected ZFrame<D,R,C> getDupesActualForGraph(ZFrame<D,R,C>dupes) {
dupes = selectColsFromDupes(dupes);
LOG.debug("dupes al");
if (LOG.isDebugEnabled()) dupes.show();
return dupes.filter(dupes.equalTo(ColName.PREDICTION_COL,ColValues.IS_MATCH_PREDICTION));
}

protected ZFrame<D,R,C> selectColsFromDupes(ZFrame<D,R,C>dupesActual) {
List<C> cols = new ArrayList<C>();
cols.add(dupesActual.col(ColName.ID_COL));
cols.add(dupesActual.col(ColName.COL_PREFIX + ColName.ID_COL));
cols.add(dupesActual.col(ColName.PREDICTION_COL));
cols.add(dupesActual.col(ColName.SCORE_COL));
ZFrame<D,R,C> dupesActual1 = dupesActual.select(cols); //.cache();
return dupesActual1;
}

protected abstract StopWordsRemover<S,D,R,C,T> getStopWords();

/**
* Each sub class of matcher can inject it's own iPairBuilder implementation
* @return
*/
public IPairBuilder<S, D, R, C> getIPairBuilder() {
if(iPairBuilder==null) {
iPairBuilder = new SelfPairBuilder<S, D, R, C> (getDSUtil(),args);
}
return iPairBuilder;
}

public void setIPairBuilder(IPairBuilder<S, D, R, C> iPairBuilder) {
this.iPairBuilder = iPairBuilder;
}



}
@@ -0,0 +1,9 @@
package zingg.common.core.filter;

import zingg.common.client.ZFrame;

public interface IFilter<D, R, C> {

public ZFrame<D, R, C> filter(ZFrame<D, R, C> df);

}
@@ -0,0 +1,29 @@
package zingg.common.core.filter;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ZFrame;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;

public class PredictionFilter<D, R, C> implements IFilter<D, R, C> {

public static final Log LOG = LogFactory.getLog(PredictionFilter.class);

public PredictionFilter() {
super();
}

@Override
public ZFrame<D, R, C> filter(ZFrame<D, R, C> dupes) {
dupes = filterMatches(dupes);
return dupes;
}

protected ZFrame<D, R, C> filterMatches(ZFrame<D, R, C> dupes) {
dupes = dupes.filter(dupes.equalTo(ColName.PREDICTION_COL,ColValues.IS_MATCH_PREDICTION));
return dupes;
}

}

0 comments on commit f96145e

Please sign in to comment.