Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use IPairBuilder for building pairs #803

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions common/core/src/main/java/zingg/common/core/executor/Linker.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,27 @@
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.pairs.SelfPairBuilderSourceSensitive;



public abstract class Linker<S,D,R,C,T> extends Matcher<S,D,R,C,T> {

private static final long serialVersionUID = 1L;
protected static String name = "zingg.Linker";
public static final Log LOG = LogFactory.getLog(Linker.class);

/**
 * Creates a linker and sets {@link ZinggOptions#LINK} as its Zingg option.
 */
public Linker() {
    this.setZinggOption(ZinggOptions.LINK);
}

/**
 * Counts the blocked frame, logs that count, and returns the cached result of
 * {@code joinWithItselfSourceSensitive} on the hash column.
 *
 * @param blocked the blocked frame that is counted and then joined with itself
 * @param bAll    not read by this implementation
 * @return the cached frame produced by the source-sensitive self join
 * @throws Exception declared by the signature
 */
public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C> blocked, ZFrame<D,R,C> bAll) throws Exception {
    // THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE
    final String blockedCountMessage = "in getBlocks, blocked count is " + blocked.count();
    LOG.info(blockedCountMessage);

    final ZFrame<D,R,C> cachedPairs =
            getDSUtil().joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args).cache();
    return cachedPairs;
}


/**
 * Returns the blocked frame as is; no column selection is applied here.
 *
 * @param blocked the blocked frame
 * @return the same {@code blocked} reference that was passed in
 */
@Override
public ZFrame<D,R,C> selectColsFromBlocked(final ZFrame<D,R,C> blocked) {
    return blocked;
}

@Override
public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws ZinggClientException {
try {
// input dupes are pairs
Expand All @@ -53,12 +52,19 @@ public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws
}
}

/**
 * Keeps only the rows whose prediction column equals
 * {@link ColValues#IS_MATCH_PREDICTION}.
 *
 * @param dupes the frame to filter
 * @return the filtered frame
 */
@Override
public ZFrame<D,R,C> getDupesActualForGraph(ZFrame<D,R,C> dupes) {
    return dupes.filter(
            dupes.equalTo(ColName.PREDICTION_COL, ColValues.IS_MATCH_PREDICTION));
}


/**
 * Returns the pair builder, creating a {@link SelfPairBuilderSourceSensitive}
 * from {@code getDSUtil()} and {@code args} on first use when none has been set.
 *
 * @return the existing or newly created pair builder
 */
@Override
public IPairBuilder<S, D, R, C> getIPairBuilder() {
    if (iPairBuilder != null) {
        return iPairBuilder;
    }
    iPairBuilder = new SelfPairBuilderSourceSensitive<S, D, R, C>(getDSUtil(), args);
    return iPairBuilder;
}

}
42 changes: 21 additions & 21 deletions common/core/src/main/java/zingg/common/core/executor/Matcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;
import zingg.common.core.model.Model;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.pairs.SelfPairBuilder;
import zingg.common.core.preprocess.StopWordsRemover;
import zingg.common.core.util.Analytics;
import zingg.common.core.util.Metric;
Expand All @@ -25,6 +27,7 @@ public abstract class Matcher<S,D,R,C,T> extends ZinggBase<S,D,R,C,T>{
protected static String name = "zingg.Matcher";
public static final Log LOG = LogFactory.getLog(Matcher.class);

protected IPairBuilder<S, D, R, C> iPairBuilder;

public Matcher() {
setZinggOption(ZinggOptions.MATCH);
Expand All @@ -50,26 +53,8 @@ public ZFrame<D,R,C> getBlocked( ZFrame<D,R,C> testData) throws Exception, Zin
return blocked1;
}

public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
ZFrame<D,R,C>joinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();
/*ZFrame<D,R,C>joinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL)
.selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid");
*/
//joinH.show();
joinH = joinH.filter(joinH.gt(ColName.ID_COL));
LOG.warn("Num comparisons " + joinH.count());
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.ID_COL));
bAll = bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.ID_COL);
LOG.warn("Joining with actual values");
//joinH.show();
bAll = getDSUtil().getPrefixedColumnsDS(bAll);
//bAll.show();
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.COL_PREFIX + ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.COL_PREFIX + ColName.ID_COL);
LOG.warn("Joining again with actual values");
//joinH.show();
return joinH;
/**
 * Delegates pair building to the builder returned by {@link #getIPairBuilder()}.
 *
 * @param blocked passed through to the pair builder
 * @param bAll    passed through to the pair builder
 * @return whatever the pair builder returns
 * @throws Exception propagated from the pair builder
 */
public ZFrame<D,R,C> getPairs(ZFrame<D,R,C> blocked, ZFrame<D,R,C> bAll) throws Exception {
    final IPairBuilder<S, D, R, C> pairBuilder = getIPairBuilder();
    return pairBuilder.getPairs(blocked, bAll);
}

/**
 * Supplies the {@link Model}; concrete subclasses must implement this.
 *
 * @return the model (NOTE(review): how it is created or loaded is not visible here)
 * @throws ZinggClientException declared by the signature; conditions are up to implementations
 */
protected abstract Model getModel() throws ZinggClientException;
Expand All @@ -91,7 +76,7 @@ protected ZFrame<D,R,C> predictOnBlocks(ZFrame<D,R,C>blocks) throws Exception, Z
}

/**
 * Builds pairs from the blocked data, runs prediction on them, and filters the
 * predictions through {@code getDupesActualForGraph}.
 *
 * @param blocked  the blocked frame; passed through {@code selectColsFromBlocked} first
 * @param testData the frame handed to {@code getPairs} as its second argument
 * @return the result of {@code getDupesActualForGraph} applied to the predictions
 * @throws Exception propagated from pair building or prediction
 * @throws ZinggClientException propagated from prediction
 */
protected ZFrame<D,R,C> getActualDupes(ZFrame<D,R,C> blocked, ZFrame<D,R,C> testData) throws Exception, ZinggClientException{
    // Pairs come from getPairs (which delegates to the IPairBuilder); the former
    // getBlocks call declared the same local a second time and is dropped.
    ZFrame<D,R,C> blocks = getPairs(selectColsFromBlocked(blocked), testData);
    ZFrame<D,R,C> dupesActual = predictOnBlocks(blocks);
    return getDupesActualForGraph(dupesActual);
}
Expand Down Expand Up @@ -285,6 +270,21 @@ protected ZFrame<D,R,C> selectColsFromDupes(ZFrame<D,R,C>dupesActual) {

/**
 * Supplies the {@link StopWordsRemover}; concrete subclasses must implement this.
 *
 * @return the stop words remover (NOTE(review): how it is used is not visible in this section)
 */
protected abstract StopWordsRemover<S,D,R,C,T> getStopWords();

/**
 * Returns the pair builder used by {@code getPairs}. When none has been set, a
 * {@link SelfPairBuilder} is created from {@code getDSUtil()} and {@code args}
 * on first use. Each subclass of matcher can override this to inject its own
 * {@link IPairBuilder} implementation.
 *
 * @return the existing or newly created pair builder
 */
public IPairBuilder<S, D, R, C> getIPairBuilder() {
    if (iPairBuilder != null) {
        return iPairBuilder;
    }
    iPairBuilder = new SelfPairBuilder<S, D, R, C>(getDSUtil(), args);
    return iPairBuilder;
}

/**
 * Replaces the pair builder returned by {@code getIPairBuilder()}.
 *
 * @param iPairBuilder the builder to store; if {@code null}, the getter creates
 *                     its default builder again on the next call
 */
public void setIPairBuilder(final IPairBuilder<S, D, R, C> iPairBuilder) {
    this.iPairBuilder = iPairBuilder;
}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package zingg.common.core.pairs;

import zingg.common.client.ZFrame;

/**
 * Strategy for turning blocked data into a frame of pairs.
 *
 * @param <S> type parameter carried through from the callers; not used by this interface itself
 * @param <D> first type parameter of {@link ZFrame}
 * @param <R> second type parameter of {@link ZFrame}
 * @param <C> third type parameter of {@link ZFrame}
 */
public interface IPairBuilder<S, D, R, C> {

    /**
     * Builds the pairs.
     *
     * @param blocked the blocked frame
     * @param bAll    a second frame made available to the implementation;
     *                implementations are free to ignore it
     * @return the frame of pairs
     * @throws Exception if the implementation fails
     */
    ZFrame<D, R, C> getPairs(ZFrame<D, R, C> blocked, ZFrame<D, R, C> bAll) throws Exception;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package zingg.common.core.pairs;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZFrame;
import zingg.common.client.util.ColName;
import zingg.common.client.util.DSUtil;

public class SelfPairBuilder<S, D, R, C> implements IPairBuilder<S, D, R, C> {

protected DSUtil<S, D, R, C> dsUtil;
public static final Log LOG = LogFactory.getLog(SelfPairBuilder.class);
protected IArguments args;

public SelfPairBuilder(DSUtil<S, D, R, C> dsUtil, IArguments args) {
this.dsUtil = dsUtil;
this.args = args;
}

@Override
public ZFrame<D, R, C> getPairs(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception {
ZFrame<D,R,C>joinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();
/*ZFrame<D,R,C>joinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL)
.selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid");
*/
//joinH.show();
joinH = joinH.filter(joinH.gt(ColName.ID_COL));
LOG.warn("Num comparisons " + joinH.count());

Check warning

Code scanning / PMD

Logger calls should be surrounded by log level guards. Warning

Logger calls should be surrounded by log level guards.
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.ID_COL));
bAll = bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.ID_COL);
LOG.warn("Joining with actual values");
//joinH.show();
bAll = getDSUtil().getPrefixedColumnsDS(bAll);
//bAll.show();
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.COL_PREFIX + ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.COL_PREFIX + ColName.ID_COL);
LOG.warn("Joining again with actual values");
//joinH.show();
return joinH;
}

public DSUtil<S, D, R, C> getDSUtil() {
return dsUtil;
}

public void setDSUtil(DSUtil<S, D, R, C> dsUtil) {
this.dsUtil = dsUtil;
}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package zingg.common.core.pairs;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZFrame;
import zingg.common.client.util.ColName;
import zingg.common.client.util.DSUtil;

public class SelfPairBuilderSourceSensitive<S, D, R, C> extends SelfPairBuilder<S, D, R, C> {

public static final Log LOG = LogFactory.getLog(SelfPairBuilderSourceSensitive.class);

public SelfPairBuilderSourceSensitive(DSUtil<S, D, R, C> dsUtil, IArguments args) {
super(dsUtil, args);
}

@Override
public ZFrame<D,R,C> getPairs(ZFrame<D,R,C> blocked, ZFrame<D,R,C> bAll) throws Exception{
// THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE
LOG.info("in getBlocks, blocked count is " + blocked.count());

Check warning

Code scanning / PMD

Logger calls should be surrounded by log level guards. Warning

Logger calls should be surrounded by log level guards.
return getDSUtil().joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args).cache();
}

}
Loading