Skip to content

Commit 480fb7b

Browse files
committed Oct 29, 2024
ftd changes
1 parent e872540 commit 480fb7b

21 files changed

+578
-36
lines changed
 

‎common/core/src/main/java/zingg/common/core/block/Block.java

+15-18
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public abstract class Block<D,R,C,T> implements Serializable {
2323
private static final long serialVersionUID = 1L;
2424

2525
public static final Log LOG = LogFactory.getLog(Block.class);
26-
private Set<String> hashFunctionsInCurrentNodePath;
26+
private final IHashFunctionUtility<D, R, C, T> hashFunctionUtility;
2727

2828
protected ZFrame<D,R,C> dupes;
2929
// Class[] types;
@@ -32,28 +32,29 @@ public abstract class Block<D,R,C,T> implements Serializable {
3232
ZFrame<D,R,C> training;
3333
protected ListMap<HashFunction<D,R,C,T>, String> childless;
3434

35-
public Block() {
36-
35+
public Block(HashUtility hashUtility) {
36+
HashFunctionUtilityFactory<D, R, C, T> hashFunctionUtilityFactory = new HashFunctionUtilityFactory<D, R, C, T>();
37+
this.hashFunctionUtility = hashFunctionUtilityFactory.getHashFunctionUtility(hashUtility);
3738
}
3839

39-
public Block(ZFrame<D,R,C> training, ZFrame<D,R,C> dupes) {
40+
public Block(ZFrame<D,R,C> training, ZFrame<D,R,C> dupes, HashUtility hashUtility) {
41+
HashFunctionUtilityFactory<D, R, C, T> hashFunctionUtilityFactory = new HashFunctionUtilityFactory<D, R, C, T>();
42+
this.hashFunctionUtility = hashFunctionUtilityFactory.getHashFunctionUtility(hashUtility);
4043
this.training = training;
4144
this.dupes = dupes;
4245
childless = new ListMap<HashFunction<D,R,C,T>, String>();
43-
hashFunctionsInCurrentNodePath = new HashSet<>();
4446
// types = getSampleTypes();
4547
/*
4648
* for (Class type : types) { LOG.info("Type is " + type); }
4749
*/
4850
}
4951

5052
public Block(ZFrame<D,R,C> training, ZFrame<D,R,C> dupes,
51-
ListMap<T, HashFunction<D, R, C, T>> functionsMap, long maxSize) {
52-
this(training, dupes);
53+
ListMap<T, HashFunction<D, R, C, T>> functionsMap, long maxSize, HashUtility hashUtility) {
54+
this(training, dupes, hashUtility);
5355
this.functionsMap = functionsMap;
5456
// functionsMap.prettyPrint();
5557
this.maxSize = maxSize;
56-
hashFunctionsInCurrentNodePath = new HashSet<>();
5758
}
5859

5960
/**
@@ -149,7 +150,7 @@ public void estimateElimCount(Canopy<R> c, long elimCount) {
149150
for (HashFunction function : functions) {
150151
// /if (!used.contains(field.getIndex(), function) &&
151152
if (least ==0) break;//how much better can it get?
152-
if (!isFunctionUsed(field, function) //&&
153+
if (!hashFunctionUtility.isHashFunctionUsed(field, function, tree, node) //&&
153154
//!childless.contains(function, field.fieldName)
154155
)
155156
{
@@ -236,7 +237,7 @@ public Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R>parent,
236237
Canopy<R>best = getBestNode(tree, parent, node, fieldsOfInterest);
237238
if (best != null) {
238239
//add function, context info for this best node in set
239-
hashFunctionsInCurrentNodePath.add(getKey(best.getContext(), best.getFunction()));
240+
hashFunctionUtility.addHashFunctionIfRequired(best);
240241
if (LOG.isDebugEnabled()) {
241242
LOG.debug(" HashFunction is " + best + " and node is " + node);
242243
}
@@ -265,7 +266,7 @@ public Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R>parent,
265266
getBlockingTree(tree, node, n, fieldsOfInterest);
266267
}
267268
//remove function, context info for this best node as we are returning from best node
268-
hashFunctionsInCurrentNodePath.remove(getKey(best.getContext(), best.getFunction()));
269+
hashFunctionUtility.removeHashFunctionIfRequired(best);
269270
}
270271
else {
271272
node.clearBeforeSaving();
@@ -287,9 +288,9 @@ public Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R>parent,
287288
return tree;
288289
}
289290

290-
public boolean isFunctionUsed(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> function) {
291-
return hashFunctionsInCurrentNodePath.contains(getKey(fieldDefinition, function));
292-
}
291+
// public boolean isFunctionUsed(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> function) {
292+
// return hashFunctionsInCurrentNodePath.contains(getKey(fieldDefinition, function));
293+
// }
293294

294295
public List<Canopy<R>> getHashSuccessors(Collection<Canopy<R>> successors, Object hash) {
295296
List<Canopy<R>> retCanopy = new ArrayList<Canopy<R>>();
@@ -375,10 +376,6 @@ public void printTree(Tree<Canopy<R>> tree,
375376
}
376377
}
377378

378-
private String getKey(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> hashFunction) {
379-
return fieldDefinition.getName() + ":" + hashFunction.getName();
380-
}
381-
382379
public abstract FeatureFactory<T> getFeatureFactory();
383380

384381

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package zingg.common.core.block;
2+
3+
import zingg.common.client.FieldDefinition;
4+
import zingg.common.core.hash.HashFunction;
5+
6+
import java.util.HashSet;
7+
import java.util.Set;
8+
9+
public class CacheBasedHashFunctionUtility<D, R, C, T> implements IHashFunctionUtility<D, R, C, T> {
10+
11+
private final Set<String> hashFunctionsInCurrentNodePath;
12+
private static final String DELIMITER = ":";
13+
14+
public CacheBasedHashFunctionUtility() {
15+
this.hashFunctionsInCurrentNodePath = new HashSet<>();
16+
}
17+
18+
@Override
19+
public boolean isHashFunctionUsed(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> hashFunction, Tree<Canopy<R>> tree, Canopy<R> node) {
20+
return hashFunctionsInCurrentNodePath.contains(getKey(fieldDefinition, hashFunction));
21+
}
22+
23+
@Override
24+
public void addHashFunctionIfRequired(Canopy<R> node) {
25+
addHashFunctionInCurrentNodePath(node);
26+
}
27+
28+
@Override
29+
public void removeHashFunctionIfRequired(Canopy<R> node) {
30+
removeHashFunctionInCurrentNodePath(node);
31+
}
32+
33+
private void addHashFunctionInCurrentNodePath(Canopy<R> node) {
34+
this.hashFunctionsInCurrentNodePath.add(getKey(node.getContext(), node.getFunction()));
35+
}
36+
37+
private void removeHashFunctionInCurrentNodePath(Canopy<R> node) {
38+
this.hashFunctionsInCurrentNodePath.remove(getKey(node.getContext(), node.getFunction()));
39+
}
40+
41+
private String getKey(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> hashFunction) {
42+
return fieldDefinition.getName() + DELIMITER + hashFunction.getName();
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package zingg.common.core.block;
2+
3+
import zingg.common.client.FieldDefinition;
4+
import zingg.common.core.hash.HashFunction;
5+
6+
public class DefaultHashFunctionUtility<D, R, C, T> implements IHashFunctionUtility<D, R, C, T>{
7+
@Override
8+
public boolean isHashFunctionUsed(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> hashFunction, Tree<Canopy<R>> tree, Canopy<R> node) {
9+
boolean isUsed = false;
10+
if (node == null || tree == null)
11+
return false;
12+
if (checkFunctionInNode(node, fieldDefinition.fieldName, hashFunction))
13+
return true;
14+
Tree<Canopy<R>> nodeTree = tree.getTree(node);
15+
if (nodeTree == null)
16+
return false;
17+
18+
Tree<Canopy<R>> parent = nodeTree.getParent();
19+
if (parent != null) {
20+
Canopy<R>head = parent.getHead();
21+
while (head != null) {
22+
// check siblings of node
23+
/*for (Tree<Canopy<R>> siblings : parent.getSubTrees()) {
24+
Canopy<R>sibling = siblings.getHead();
25+
if (checkFunctionInNode(sibling, index, function))
26+
return true;
27+
}*/
28+
// check parent of node
29+
return isHashFunctionUsed(fieldDefinition, hashFunction, tree, head);
30+
}
31+
}
32+
return isUsed;
33+
}
34+
35+
@Override
36+
public void addHashFunctionIfRequired(Canopy<R> node) {
37+
//don't add hashFunction to cache
38+
//as we are in default mode
39+
}
40+
41+
@Override
42+
public void removeHashFunctionIfRequired(Canopy<R> node) {
43+
//don't remove hashFunction from cache
44+
//as we are in default mode
45+
}
46+
47+
private boolean checkFunctionInNode(Canopy<R>node, String name,
48+
HashFunction<D, R, C, T> function) {
49+
if (node.getFunction() != null && node.getFunction().equals(function)
50+
&& node.context.fieldName.equals(name)) {
51+
return true;
52+
}
53+
return false;
54+
}
55+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package zingg.common.core.block;
2+
3+
public class HashFunctionUtilityFactory<D, R, C, T> {
4+
5+
public IHashFunctionUtility<D, R, C, T> getHashFunctionUtility(HashUtility hashUtility) {
6+
7+
if (HashUtility.DEFAULT.equals(hashUtility)) {
8+
return new DefaultHashFunctionUtility<D, R, C, T>();
9+
}
10+
return new CacheBasedHashFunctionUtility<D, R, C, T>();
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package zingg.common.core.block;
2+
3+
/**
 * Selects which IHashFunctionUtility implementation is used while building
 * the blocking tree.
 */
public enum HashUtility {
    /** Tree-walking lookup on every query (DefaultHashFunctionUtility). */
    DEFAULT,
    /** Set-backed lookup of used (field, function) keys (CacheBasedHashFunctionUtility). */
    CACHED
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package zingg.common.core.block;
2+
3+
import zingg.common.client.FieldDefinition;
4+
import zingg.common.core.hash.HashFunction;
5+
6+
/**
 * Strategy for tracking which (field, hash function) pairs have already been
 * applied along the current path of the blocking tree.
 */
public interface IHashFunctionUtility<D, R, C, T> {

    /**
     * Returns true if {@code hashFunction} has already been applied to
     * {@code fieldDefinition} on the path leading to {@code node}.
     * Implementations may consult {@code tree}/{@code node} or an internal
     * cache and ignore the other arguments.
     */
    boolean isHashFunctionUsed(FieldDefinition fieldDefinition, HashFunction<D, R, C, T> hashFunction, Tree<Canopy<R>> tree, Canopy<R> node);

    /** Records {@code node}'s (field, function) pair; may be a no-op. */
    void addHashFunctionIfRequired(Canopy<R> node);

    /** Forgets {@code node}'s (field, function) pair; may be a no-op. */
    void removeHashFunctionIfRequired(Canopy<R> node);
}

‎common/core/src/main/java/zingg/common/core/executor/Trainer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import zingg.common.client.util.ColName;
99
import zingg.common.client.util.ColValues;
1010
import zingg.common.core.block.Canopy;
11+
import zingg.common.core.block.HashUtility;
1112
import zingg.common.core.block.Tree;
1213
import zingg.common.core.model.Model;
1314
import zingg.common.core.util.Analytics;
@@ -42,7 +43,7 @@ public void execute() throws ZinggClientException {
4243
ZFrame<D,R,C> testData = getStopWords().preprocessForStopWords(testDataOriginal);
4344

4445
Tree<Canopy<R>> blockingTree = getBlockingTreeUtil().createBlockingTreeFromSample(testData, positives, 0.5,
45-
-1, args, getHashUtil().getHashFunctionList());
46+
-1, args, getHashUtil().getHashFunctionList(), HashUtility.CACHED);
4647
if (blockingTree == null || blockingTree.getSubTrees() == null) {
4748
LOG.warn("Seems like no indexing rules have been learnt");
4849
}

‎common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import zingg.common.client.util.ColName;
1414
import zingg.common.client.util.ColValues;
1515
import zingg.common.core.block.Canopy;
16+
import zingg.common.core.block.HashUtility;
1617
import zingg.common.core.block.Tree;
1718
import zingg.common.core.model.Model;
1819
import zingg.common.core.preprocess.StopWordsRemover;
@@ -87,7 +88,7 @@ public void execute() throws ZinggClientException {
8788

8889
ZFrame<D,R,C> sample = getStopWords().preprocessForStopWords(sampleOrginal);
8990

90-
Tree<Canopy<R>> tree = getBlockingTreeUtil().createBlockingTree(sample, posPairs, 1, -1, args, getHashUtil().getHashFunctionList());
91+
Tree<Canopy<R>> tree = getBlockingTreeUtil().createBlockingTree(sample, posPairs, 1, -1, args, getHashUtil().getHashFunctionList(), HashUtility.CACHED);
9192
//tree.print(2);
9293
ZFrame<D,R,C> blocked = getBlockingTreeUtil().getBlockHashes(sample, tree);
9394

‎common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java

+9-8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import zingg.common.client.util.Util;
1717
import zingg.common.core.block.Block;
1818
import zingg.common.core.block.Canopy;
19+
import zingg.common.core.block.HashUtility;
1920
import zingg.common.core.block.Tree;
2021
import zingg.common.core.hash.HashFunction;
2122

@@ -36,13 +37,13 @@ public void setPipeUtil(PipeUtilBase<S, D, R, C> pipeUtil) {
3637

3738

3839
public abstract Block<D,R,C,T> getBlock(ZFrame<D,R,C> sample, ZFrame<D,R,C> positives,
39-
ListMap<T, HashFunction<D,R,C,T>>hashFunctions, long blockSize);
40+
ListMap<T, HashFunction<D,R,C,T>>hashFunctions, long blockSize, HashUtility hashUtility);
4041

4142

42-
public Tree<Canopy<R>> createBlockingTree(ZFrame<D,R,C> testData,
43-
ZFrame<D,R,C> positives, double sampleFraction, long blockSize,
44-
IArguments args,
45-
ListMap<T, HashFunction<D,R,C,T>> hashFunctions) throws Exception, ZinggClientException {
43+
public Tree<Canopy<R>> createBlockingTree(ZFrame<D,R,C> testData,
44+
ZFrame<D,R,C> positives, double sampleFraction, long blockSize,
45+
IArguments args,
46+
ListMap<T, HashFunction<D,R,C,T>> hashFunctions, HashUtility hashUtility) throws Exception, ZinggClientException {
4647
ZFrame<D,R,C> sample = testData.sample(false, sampleFraction);
4748
sample = sample.cache();
4849
long totalCount = sample.count();
@@ -54,7 +55,7 @@ public Tree<Canopy<R>> createBlockingTree(ZFrame<D,R,C> testData,
5455
LOG.info("Learning indexing rules for block size " + blockSize);
5556

5657
positives = positives.coalesce(1);
57-
Block<D,R,C,T> cblock = getBlock(sample, positives, hashFunctions, blockSize);
58+
Block<D,R,C,T> cblock = getBlock(sample, positives, hashFunctions, blockSize, hashUtility);
5859
Canopy<R> root = new Canopy<R>(sample.collectAsList(), positives.collectAsList());
5960

6061
List<FieldDefinition> fd = new ArrayList<FieldDefinition> ();
@@ -78,9 +79,9 @@ public Tree<Canopy<R>> createBlockingTree(ZFrame<D,R,C> testData,
7879

7980
public Tree<Canopy<R>> createBlockingTreeFromSample(ZFrame<D,R,C> testData,
8081
ZFrame<D,R,C> positives, double sampleFraction, long blockSize, IArguments args,
81-
ListMap hashFunctions) throws Exception, ZinggClientException {
82+
ListMap hashFunctions, HashUtility hashUtility) throws Exception, ZinggClientException {
8283
ZFrame<D,R,C> sample = testData.sample(false, sampleFraction);
83-
return createBlockingTree(sample, positives, sampleFraction, blockSize, args, hashFunctions);
84+
return createBlockingTree(sample, positives, sampleFraction, blockSize, args, hashFunctions, hashUtility);
8485
}
8586

8687
public void writeBlockingTree(Tree<Canopy<R>> blockingTree, IArguments args) throws Exception, ZinggClientException {

‎common/core/src/test/java/zingg/common/core/block/TestBlockBase.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void testTree() throws Throwable {
4343
IArguments args = getArguments();
4444

4545
Tree<Canopy<R>> blockingTree = blockingTreeUtil.createBlockingTreeFromSample(zFrameEvent, zFrameEventCluster, 0.5, -1,
46-
args, hashUtil.getHashFunctionList());
46+
args, hashUtil.getHashFunctionList(), HashUtility.CACHED);
4747

4848
// primary deciding is unique year so identityInteger should have been picked
4949
Canopy<R> head = blockingTree.getHead();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package zingg.common.core.block;
2+
3+
import org.junit.jupiter.api.Assertions;
4+
import org.junit.jupiter.api.Test;
5+
import zingg.common.client.Arguments;
6+
import zingg.common.client.ArgumentsUtil;
7+
import zingg.common.client.IArguments;
8+
import zingg.common.client.ZFrame;
9+
import zingg.common.client.ZinggClientException;
10+
import zingg.common.client.util.DFObjectUtil;
11+
import zingg.common.core.block.dataUtility.CsvReader;
12+
import zingg.common.core.block.dataUtility.DataUtility;
13+
import zingg.common.core.block.model.Customer;
14+
import zingg.common.core.block.model.CustomerDupe;
15+
import zingg.common.core.util.BlockingTreeUtil;
16+
import zingg.common.core.util.HashUtil;
17+
18+
import java.util.Iterator;
19+
import java.util.List;
20+
import java.util.Objects;
21+
22+
import static java.lang.Math.max;
23+
24+
public abstract class TestBlockingTreeUtil<S, D, R, C, T> {
25+
26+
protected String TEST_DATA_BASE_LOCATION;
27+
private int maxDepth = 1, totalNodes = 0;
28+
private static String TEST_FILE = "test.csv";
29+
private static String CONFIG_FILE = "config.json";
30+
private final DataUtility dataUtility;
31+
32+
public TestBlockingTreeUtil() {
33+
setTestDataBaseLocation();
34+
this.dataUtility = new DataUtility(new CsvReader());
35+
}
36+
37+
@Test
38+
public void testSameBlockingTreeWithoutVariance() throws Exception, ZinggClientException {
39+
List<Customer> testCustomers = dataUtility.getCustomers(TEST_DATA_BASE_LOCATION + "/" + TEST_FILE);
40+
//setting variance as false
41+
List<CustomerDupe> testCustomerDupes = dataUtility.getCustomerDupes(TEST_DATA_BASE_LOCATION + "/" + TEST_FILE, false);
42+
DFObjectUtil<S, D, R, C> dfObjectUtil = getDFObjectUtil();
43+
44+
ZFrame<D, R, C> zFrameTest = dfObjectUtil.getDFFromObjectList(testCustomers, Customer.class);
45+
ZFrame<D, R, C> zFramePositives = dfObjectUtil.getDFFromObjectList(testCustomerDupes, CustomerDupe.class);
46+
47+
testSameBlockingTree(zFrameTest, zFramePositives);
48+
}
49+
50+
@Test
51+
public void testSameBlockingTreeWithVariance() throws Exception, ZinggClientException {
52+
List<Customer> testCustomers = dataUtility.getCustomers(TEST_DATA_BASE_LOCATION + "/" + TEST_FILE);
53+
//setting variance as true
54+
List<CustomerDupe> testCustomerDupes = dataUtility.getCustomerDupes(TEST_DATA_BASE_LOCATION + "/" + TEST_FILE, true);
55+
DFObjectUtil<S, D, R, C> dfObjectUtil = getDFObjectUtil();
56+
57+
ZFrame<D, R, C> zFrameTest = dfObjectUtil.getDFFromObjectList(testCustomers, Customer.class);
58+
ZFrame<D, R, C> zFramePositives = dfObjectUtil.getDFFromObjectList(testCustomerDupes, CustomerDupe.class);
59+
60+
testSameBlockingTree(zFrameTest, zFramePositives);
61+
}
62+
63+
public void testSameBlockingTree(ZFrame<D, R, C> zFrameTest, ZFrame<D, R, C> zFramePositives) throws Exception, ZinggClientException {
64+
setTestDataBaseLocation();
65+
BlockingTreeUtil<S, D, R, C, T> blockingTreeUtil = getBlockingTreeUtil();
66+
HashUtil<S, D, R, C, T> hashUtil = getHashUtil();
67+
68+
69+
IArguments args = new ArgumentsUtil(Arguments.class).createArgumentsFromJSON(
70+
TEST_DATA_BASE_LOCATION + "/" + CONFIG_FILE,
71+
"");
72+
args.setBlockSize(8);
73+
74+
long ts = System.currentTimeMillis();
75+
Tree<Canopy<R>> blockingTreeOptimized = blockingTreeUtil.createBlockingTree(zFrameTest, zFramePositives, 1, -1,
76+
args, hashUtil.getHashFunctionList(), HashUtility.CACHED);
77+
System.out.println("************ time taken to create optimized blocking tree ************ " + (System.currentTimeMillis() - ts));
78+
79+
ts = System.currentTimeMillis();
80+
Tree<Canopy<R>> blockingTreeDefault = blockingTreeUtil.createBlockingTree(zFrameTest, zFramePositives, 1, -1,
81+
args, hashUtil.getHashFunctionList(), HashUtility.DEFAULT);
82+
System.out.println("************ time taken to create blocking tree ************ " + (System.currentTimeMillis() - ts));
83+
84+
int depth = 1;
85+
//assert both the trees are equal
86+
Assertions.assertTrue(dfsSameTreeValidation(blockingTreeDefault, blockingTreeOptimized, depth));
87+
88+
System.out.println("-------- max depth of trees -------- " + maxDepth);
89+
System.out.println("-------- total nodes in a trees -------- " + totalNodes);
90+
}
91+
92+
93+
private boolean dfsSameTreeValidation(Tree<Canopy<R>> node1, Tree<Canopy<R>> node2, int depth) {
94+
totalNodes++;
95+
maxDepth = max(maxDepth, depth);
96+
97+
//if both the node1 and node2 are null, return true
98+
if(node1 == null && node2 == null){
99+
return true;
100+
}
101+
//if only one of node1 or node2 is null, return false
102+
if(node1 == null || node2 == null){
103+
return false;
104+
}
105+
106+
if (!performValidationOnNode1AndNode2(node1, node2)) {
107+
return false;
108+
}
109+
110+
Iterator<Tree<Canopy<R>>> canopyIterator1 = node1.getSubTrees().iterator();
111+
Iterator<Tree<Canopy<R>>> canopyIterator2 = node2.getSubTrees().iterator();
112+
113+
boolean isEqual = true;
114+
115+
//recurse through sub-trees
116+
while (canopyIterator1.hasNext() && canopyIterator2.hasNext()) {
117+
isEqual &= dfsSameTreeValidation(canopyIterator1.next(), canopyIterator2.next(), depth + 1);
118+
}
119+
120+
return isEqual;
121+
}
122+
123+
124+
private boolean performValidationOnNode1AndNode2(Tree<Canopy<R>> node1, Tree<Canopy<R>> node2) {
125+
boolean functionEqual = isNodeFunctionEqual(node1.getHead(), node2.getHead());
126+
boolean contextEqual = isNodeContextEqual(node1.getHead(), node2.getHead());
127+
boolean hashEqual = isNodeHashEqual(node1.getHead(), node2.getHead());
128+
boolean subtreeSizeEqual = isNodeSubTreesSizeEqual(node1, node2);
129+
130+
return functionEqual && contextEqual && hashEqual && subtreeSizeEqual;
131+
}
132+
private boolean isNodeFunctionEqual(Canopy<R> node1Head, Canopy<R> node2Head) {
133+
if (node1Head.getFunction() == null && node2Head.getFunction() == null) {
134+
return true;
135+
} else if (node1Head.getFunction() == null || node2Head.getFunction() == null) {
136+
return false;
137+
} else {
138+
return Objects.equals(node1Head.getFunction().getName(), node2Head.getFunction().getName());
139+
}
140+
}
141+
142+
private boolean isNodeHashEqual(Canopy<R> node1Head, Canopy<R> node2Head) {
143+
return Objects.equals(node1Head.getHash(), node2Head.getHash());
144+
}
145+
146+
private boolean isNodeContextEqual(Canopy<R> node1Head, Canopy<R> node2Head) {
147+
148+
if (node1Head.getContext() == null && node2Head.getContext() == null) {
149+
return true;
150+
} else if (node1Head.getContext() == null || node2Head.getContext() == null) {
151+
return false;
152+
} else {
153+
return Objects.equals(node1Head.getContext().getName(), node2Head.getContext().getName());
154+
}
155+
}
156+
157+
private boolean isNodeSubTreesSizeEqual(Tree<Canopy<R>> node1, Tree<Canopy<R>> node2) {
158+
return node1.getSubTrees().size() == node2.getSubTrees().size();
159+
}
160+
161+
162+
protected abstract DFObjectUtil<S, D, R, C> getDFObjectUtil();
163+
protected abstract BlockingTreeUtil<S, D, R, C, T> getBlockingTreeUtil();
164+
protected abstract HashUtil<S, D, R, C, T> getHashUtil();
165+
protected abstract void setTestDataBaseLocation();
166+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package zingg.common.core.block.dataUtility;
2+
3+
import com.opencsv.CSVReader;
4+
import com.opencsv.CSVReaderBuilder;
5+
import com.opencsv.exceptions.CsvException;
6+
7+
import java.io.FileNotFoundException;
8+
import java.io.FileReader;
9+
import java.io.IOException;
10+
import java.util.List;
11+
12+
public class CsvReader implements DataReader{
13+
14+
@Override
15+
public List<String[]> readDataFromSource(String source) throws IOException, CsvException {
16+
CSVReader csvReader = getCSVReader(source);
17+
List<String[]> allData = csvReader.readAll();
18+
return allData;
19+
}
20+
21+
22+
private CSVReader getCSVReader(String source) throws FileNotFoundException {
23+
FileReader filereader = new FileReader(source);
24+
com.opencsv.CSVReader csvReader = new CSVReaderBuilder(filereader)
25+
.withSkipLines(1)
26+
.build();
27+
return csvReader;
28+
}
29+
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package zingg.common.core.block.dataUtility;
2+
3+
import com.opencsv.exceptions.CsvException;
4+
5+
import java.io.IOException;
6+
import java.util.List;
7+
8+
/**
 * Reads tabular test data from a source location.
 */
public interface DataReader {
    /**
     * Reads all data rows from {@code source}.
     *
     * @param source location of the data (e.g. a CSV file path)
     * @return one String[] of column values per row
     * @throws IOException  if the source cannot be read
     * @throws CsvException if the content cannot be parsed as CSV
     */
    List<String[]> readDataFromSource(String source) throws IOException, CsvException;
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package zingg.common.core.block.dataUtility;
2+
3+
import com.opencsv.exceptions.CsvException;
4+
import zingg.common.core.block.model.Customer;
5+
import zingg.common.core.block.model.CustomerDupe;
6+
7+
import java.io.IOException;
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
public class DataUtility {
12+
13+
private final DataReader dataReader;
14+
15+
public DataUtility(DataReader dataReader) {
16+
this.dataReader = dataReader;
17+
}
18+
19+
public List<CustomerDupe> getCustomerDupes(String source, boolean varianceAdded) throws IOException, CsvException {
20+
21+
List<CustomerDupe> testCustomerDupes = new ArrayList<>();
22+
23+
List<String[]> allData = dataReader.readDataFromSource(source);
24+
for (String[] row : allData) {
25+
String[] dupe = new String[2 * row.length];
26+
System.arraycopy(row, 0, dupe, 0, row.length);
27+
String[] sideRow;
28+
if (varianceAdded) {
29+
sideRow = getVarianceAddedRow(row);
30+
} else {
31+
sideRow = getNonVarianceAddedRow(row);
32+
}
33+
System.arraycopy(sideRow, 0, dupe, sideRow.length, sideRow.length);
34+
testCustomerDupes.add(new CustomerDupe(dupe));
35+
}
36+
return testCustomerDupes;
37+
}
38+
39+
40+
public List<Customer> getCustomers(String source) throws IOException, CsvException {
41+
42+
List<Customer> testCustomers = new ArrayList<>();
43+
44+
List<String[]> allData = dataReader.readDataFromSource(source);
45+
for (String[] row : allData) {
46+
testCustomers.add(new Customer(row));
47+
}
48+
return testCustomers;
49+
}
50+
51+
private String[] getVarianceAddedRow(String[] row) {
52+
String[] varianceAddedRow = new String[row.length];
53+
varianceAddedRow[0] = row[0];
54+
for(int idx = 1; idx < row.length; idx++) {
55+
varianceAddedRow[idx] = "v_" + row[idx] + "_v";
56+
}
57+
58+
return varianceAddedRow;
59+
}
60+
61+
private String[] getNonVarianceAddedRow(String[] row) {
62+
return row;
63+
}
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package zingg.common.core.block.model;
2+
3+
import java.util.List;
4+
5+
/**
 * Flat customer record used by the blocking-tree tests; fields mirror the
 * CSV column order of the test data.
 */
public class Customer {
    String id;
    String fname;
    String lname;
    String stNo;
    String add1;
    String add2;
    String city;
    String areacode;
    String state;
    String dob;
    String ssn;

    /**
     * @param arguments column values in CSV order: id, fname, lname, stNo,
     *                  add1, add2, city, areacode, state, dob, ssn
     */
    public Customer(String... arguments) {
        // index the varargs array directly (consistent with CustomerDupe);
        // the previous List.of copy was needless and rejected null columns
        this.id = arguments[0];
        this.fname = arguments[1];
        this.lname = arguments[2];
        this.stNo = arguments[3];
        this.add1 = arguments[4];
        this.add2 = arguments[5];
        this.city = arguments[6];
        this.areacode = arguments[7];
        this.state = arguments[8];
        this.dob = arguments[9];
        this.ssn = arguments[10];
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package zingg.common.core.block.model;
2+
3+
/**
 * A pair of customer records flattened into one row: the original columns
 * followed by the z_-prefixed columns of the paired record. Column order on
 * both sides follows Customer's CSV order: id, fname, lname, stNo, add1,
 * add2, city, areacode, state, dob, ssn.
 */
public class CustomerDupe {
    String id;
    String fname;
    String lname;
    String stNo;
    String add1;
    String add2;
    String city;
    String areacode;
    String state;
    String dob;
    String ssn;
    String z_zid;
    String z_fname;
    String z_lname;
    String z_stNo;
    String z_add1;
    String z_add2;
    String z_city;
    String z_areacode;
    String z_state;
    String z_dob;
    String z_ssn;

    /**
     * @param arguments 22 column values: indices 0-10 the first record,
     *                  11-21 the paired record, each in CSV order
     */
    public CustomerDupe(String... arguments) {
        this.id = arguments[0];
        this.fname = arguments[1];
        this.lname = arguments[2];
        this.stNo = arguments[3];
        this.add1 = arguments[4];
        this.add2 = arguments[5];
        this.city = arguments[6];
        // BUGFIX: areacode/state were swapped (state=7, areacode=8) — the CSV
        // order per Customer and the z_-side below is areacode then state
        this.areacode = arguments[7];
        this.state = arguments[8];
        this.dob = arguments[9];
        this.ssn = arguments[10];
        this.z_zid = arguments[11];
        this.z_fname = arguments[12];
        this.z_lname = arguments[13];
        this.z_stNo = arguments[14];
        this.z_add1 = arguments[15];
        this.z_add2 = arguments[16];
        this.z_city = arguments[17];
        this.z_areacode = arguments[18];
        this.z_state = arguments[19];
        this.z_dob = arguments[20];
        this.z_ssn = arguments[21];
    }
}

‎examples/febrl120k/config120k.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"name":"test",
7575
"format":"csv",
7676
"props": {
77-
"location": "examples/febrl120k/test.csv.gz",
77+
"location": "examples/febrl120k/test1l20k.csv.orig",
7878
"delimiter": ",",
7979
"header":false
8080
},

‎pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@
9393
</repositories>
9494

9595
<dependencies>
96+
<dependency>
97+
<groupId>com.opencsv</groupId>
98+
<artifactId>opencsv</artifactId>
99+
<version>5.9</version>
100+
</dependency>
96101
<dependency>
97102
<groupId>org.junit.jupiter</groupId>
98103
<artifactId>junit-jupiter-engine</artifactId>

‎spark/core/src/main/java/zingg/spark/core/block/SparkBlock.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import zingg.common.client.ZFrame;
99
import zingg.common.client.util.ListMap;
1010
import zingg.common.core.block.Block;
11+
import zingg.common.core.block.HashUtility;
1112
import zingg.common.core.feature.FeatureFactory;
1213
import zingg.common.core.hash.HashFunction;
1314
import zingg.spark.core.feature.SparkFeatureFactory;
@@ -17,12 +18,14 @@ public class SparkBlock extends Block<Dataset<Row>, Row, Column, DataType> {
1718
private static final long serialVersionUID = 1L;
1819

1920

20-
public SparkBlock(){}
21+
public SparkBlock(HashUtility hashUtility){
22+
super(hashUtility);
23+
}
2124

2225

2326
public SparkBlock(ZFrame<Dataset<Row>, Row, Column> training, ZFrame<Dataset<Row>, Row, Column> dupes,
24-
ListMap<DataType, HashFunction<Dataset<Row>, Row, Column, DataType>> functionsMap, long maxSize) {
25-
super(training, dupes, functionsMap, maxSize);
27+
ListMap<DataType, HashFunction<Dataset<Row>, Row, Column, DataType>> functionsMap, long maxSize, HashUtility hashUtility) {
28+
super(training, dupes, functionsMap, maxSize, hashUtility);
2629
}
2730

2831
@Override

‎spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import zingg.common.client.util.PipeUtilBase;
2424
import zingg.common.core.block.Block;
2525
import zingg.common.core.block.Canopy;
26+
import zingg.common.core.block.HashUtility;
2627
import zingg.common.core.block.Tree;
2728
import zingg.common.core.hash.HashFunction;
2829
import zingg.common.core.util.BlockingTreeUtil;
@@ -85,9 +86,10 @@ public Tree<Canopy<Row>> readBlockingTree(Arguments args) throws Exception, Zing
8586

8687
@Override
8788
public Block<Dataset<Row>, Row, Column, DataType> getBlock(ZFrame<Dataset<Row>, Row, Column> sample,
88-
ZFrame<Dataset<Row>, Row, Column> positives,
89-
ListMap<DataType, HashFunction<Dataset<Row>, Row, Column, DataType>> hashFunctions, long blockSize) {
89+
ZFrame<Dataset<Row>, Row, Column> positives,
90+
ListMap<DataType, HashFunction<Dataset<Row>, Row, Column, DataType>> hashFunctions,
91+
long blockSize, HashUtility hashUtility) {
9092
// TODO Auto-generated method stub
91-
return new SparkBlock(sample, positives, hashFunctions, blockSize);
93+
return new SparkBlock(sample, positives, hashFunctions, blockSize, hashUtility);
9294
}
9395
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package zingg.common.core.block;
2+
3+
import org.apache.spark.sql.Column;
4+
import org.apache.spark.sql.Dataset;
5+
import org.apache.spark.sql.Row;
6+
import org.apache.spark.sql.SparkSession;
7+
import org.apache.spark.sql.types.DataType;
8+
import org.junit.jupiter.api.extension.ExtendWith;
9+
import zingg.TestSparkBase;
10+
import zingg.common.client.util.DFObjectUtil;
11+
import zingg.common.client.util.IWithSession;
12+
import zingg.common.client.util.WithSession;
13+
import zingg.common.core.util.BlockingTreeUtil;
14+
import zingg.common.core.util.HashUtil;
15+
import zingg.spark.client.util.SparkDFObjectUtil;
16+
import zingg.spark.client.util.SparkPipeUtil;
17+
import zingg.spark.core.util.SparkBlockingTreeUtil;
18+
import zingg.spark.core.util.SparkHashUtil;
19+
20+
/**
 * Spark binding of TestBlockingTreeUtil: supplies Spark-backed frame,
 * blocking-tree and hash utilities for the tree-equivalence tests.
 */
@ExtendWith(TestSparkBase.class)
public class TestSparkBlockingTreeUtil extends TestBlockingTreeUtil<SparkSession, Dataset<Row>, Row, Column, DataType> {

    // session holder shared by every Spark utility created below
    private final IWithSession<SparkSession> withSession;

    public TestSparkBlockingTreeUtil(SparkSession sparkSession) {
        withSession = new WithSession<>();
        withSession.setSession(sparkSession);
    }

    @Override
    protected DFObjectUtil<SparkSession, Dataset<Row>, Row, Column> getDFObjectUtil() {
        return new SparkDFObjectUtil(withSession);
    }

    @Override
    protected BlockingTreeUtil<SparkSession, Dataset<Row>, Row, Column, DataType> getBlockingTreeUtil() {
        return new SparkBlockingTreeUtil(withSession.getSession(), new SparkPipeUtil(withSession.getSession()));
    }

    @Override
    protected HashUtil<SparkSession, Dataset<Row>, Row, Column, DataType> getHashUtil() {
        return new SparkHashUtil(withSession.getSession());
    }

    @Override
    protected void setTestDataBaseLocation() {
        // TODO(review): machine-specific absolute path — breaks on any other
        // checkout/CI; should resolve a repo-relative examples/febrl directory
        TEST_DATA_BASE_LOCATION = "/home/administrator/zingg/zinggOSS/examples/febrl";
    }
}

0 commit comments

Comments
 (0)
Please sign in to comment.