Skip to content

Commit 9c6df7d

Browse files
committed
blocking tree changes
1 parent 80a7ef0 commit 9c6df7d

File tree

16 files changed

+477
-11
lines changed

16 files changed

+477
-11
lines changed

common/core/src/main/java/zingg/common/core/block/Block.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.Collection;
66
import java.util.List;
77
import java.util.Map;
8+
import java.util.Set;
89

910
import org.apache.commons.logging.Log;
1011
import org.apache.commons.logging.LogFactory;
@@ -21,6 +22,7 @@ public abstract class Block<D,R,C,T> implements Serializable {
2122
private static final long serialVersionUID = 1L;
2223

2324
public static final Log LOG = LogFactory.getLog(Block.class);
25+
private Set<String> hashFunctionsInCurrentNodePath;
2426

2527
protected ZFrame<D,R,C> dupes;
2628
// Class[] types;
@@ -287,8 +289,8 @@ public boolean checkFunctionInNode(Canopy<R>node, String name,
287289
return false;
288290
}
289291

290-
public boolean isFunctionUsed(Tree<Canopy<R>> tree, Canopy<R>node, String fieldName,
291-
HashFunction function) {
292+
public boolean isFunctionUsedDefault(Tree<Canopy<R>> tree, Canopy<R>node, String fieldName,
293+
HashFunction function) {
292294
// //LOG.debug("Tree " + tree);
293295
// //LOG.debug("Node " + node);
294296
// //LOG.debug("Index " + index);
@@ -318,6 +320,23 @@ public boolean isFunctionUsed(Tree<Canopy<R>> tree, Canopy<R>node, String fieldN
318320
}
319321
return isUsed;
320322
}
323+
324+
public boolean isFunctionUsed(Tree<Canopy<R>> tree, Canopy<R> node, String fieldName, HashFunction function) {
325+
326+
//default to original implementation
327+
if (hashFunctionsInCurrentNodePath == null) {
328+
return isFunctionUsedDefault(tree, node, fieldName, function);
329+
}
330+
return hashFunctionsInCurrentNodePath.contains(function.getName() + ":" + fieldName);
331+
}
332+
333+
public void setHashFunctionsInCurrentNodePath(Set<String> hashFunctionsInCurrentNodePath) {
334+
this.hashFunctionsInCurrentNodePath = hashFunctionsInCurrentNodePath;
335+
}
336+
337+
public Set<String> getHashFunctionsInCurrentNodePath() {
338+
return this.hashFunctionsInCurrentNodePath;
339+
}
321340

322341

323342
public List<Canopy<R>> getHashSuccessors(Collection<Canopy<R>> successors, Object hash) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package zingg.common.core.block.blockingTree;

import zingg.common.client.FieldDefinition;
import zingg.common.client.ZinggClientException;
import zingg.common.core.block.Block;
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;

import java.util.List;

/**
 * Blocking-tree builder that preserves the legacy behaviour: it forwards the
 * request unchanged to {@link Block#getBlockingTree} on the supplied block,
 * performing no path-tracking optimisation of its own.
 *
 * NOTE(review): the class name is missing an 'l' ("Bocking" instead of
 * "Blocking"); renaming would break existing callers, so it should be fixed
 * in a follow-up with a deprecation alias.
 */
public abstract class DefaultBockingTreeBuilder<D, R, C, T> extends Block<D, R, C, T>
        implements IBlockingTreeBuilder<D, R, C, T> {

    /**
     * Delegates tree construction to the original implementation on {@code cblock}.
     *
     * @param tree             tree built so far (may be null at the root)
     * @param parent           parent canopy (may be null at the root)
     * @param node             canopy currently being expanded
     * @param fieldsOfInterest field definitions considered for blocking
     * @param cblock           block instance carrying the legacy algorithm
     * @return the blocking tree produced by the legacy implementation
     */
    @Override
    public Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R> parent, Canopy<R> node,
            List<FieldDefinition> fieldsOfInterest, Block<D, R, C, T> cblock)
            throws ZinggClientException, Exception {
        LOG.info("--------- using default blocking tree builder ---------");
        return cblock.getBlockingTree(tree, parent, node, fieldsOfInterest);
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package zingg.common.core.block.blockingTree;

import zingg.common.client.FieldDefinition;
import zingg.common.client.ZinggClientException;
import zingg.common.core.block.Block;
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;

import java.util.List;

/**
 * Strategy interface for building a blocking tree. Implementations decide how
 * the tree is grown (e.g. the legacy recursive walk, or an optimised variant
 * that caches the hash functions used on the current path).
 */
public interface IBlockingTreeBuilder<D, R, C, T> {

    /**
     * Builds (or extends) the blocking tree rooted at {@code node}.
     *
     * @param tree             tree built so far; null when starting from scratch
     * @param parent           parent canopy; null at the root
     * @param node             canopy to expand
     * @param fieldsOfInterest field definitions eligible for blocking
     * @param cblock           block instance supplying the candidate-selection logic
     * @return the resulting blocking tree
     */
    Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R> parent, Canopy<R> node,
            List<FieldDefinition> fieldsOfInterest, Block<D, R, C, T> cblock) throws Exception, ZinggClientException;
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package zingg.common.core.block.blockingTree;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import zingg.common.client.FieldDefinition;
import zingg.common.client.ZinggClientException;
import zingg.common.core.block.Block;
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Blocking-tree builder that maintains the set of (hash function, field) pairs
 * already applied along the current root-to-node path. Seeding this set on the
 * block lets {@code Block#isFunctionUsed} answer with an O(1) set lookup
 * instead of re-walking the tree for every candidate.
 */
public abstract class OptimizedBlockingTreeBuilder<D, R, C, T> extends Block<D, R, C, T>
        implements IBlockingTreeBuilder<D, R, C, T> {

    // Intentionally shadows Block.LOG so messages are attributed to this class.
    public static final Log LOG = LogFactory.getLog(OptimizedBlockingTreeBuilder.class);

    /**
     * Entry point: installs a fresh path-cache on {@code cblock} (switching
     * Block#isFunctionUsed into set-lookup mode) and runs a DFS to grow the tree.
     */
    @Override
    public Tree<Canopy<R>> getBlockingTree(Tree<Canopy<R>> tree, Canopy<R> parent, Canopy<R> node,
            List<FieldDefinition> fieldsOfInterest, Block<D, R, C, T> cblock)
            throws Exception, ZinggClientException {
        LOG.info("--------- using optimized blocking tree builder ---------");
        cblock.setHashFunctionsInCurrentNodePath(new HashSet<>());
        Tree<Canopy<R>> blockingTree =
                dfsToGetBlockingTree(tree, parent, node, fieldsOfInterest,
                        cblock.getHashFunctionsInCurrentNodePath(), cblock);
        return blockingTree;
    }

    /**
     * Depth-first expansion of {@code node}: while the node is larger than the
     * configured maximum and still has duplicate pairs, pick the best canopy and
     * recurse into its children; otherwise stop (warning when training data ran
     * out, failing when no tree could be built at all).
     */
    private Tree<Canopy<R>> dfsToGetBlockingTree(Tree<Canopy<R>> tree, Canopy<R> parent, Canopy<R> node,
            List<FieldDefinition> fieldsOfInterest, Set<String> hashFunctionsInCurrentNodePath,
            Block<D, R, C, T> cblock) throws ZinggClientException, Exception {
        long size = node.getTrainingSize();
        if (size > cblock.getMaxSize() && node.getDupeN() != null && !node.getDupeN().isEmpty()) {
            Canopy<R> best = cblock.getBestNode(tree, parent, node, fieldsOfInterest);
            if (best != null) {
                // First expansion: the root node becomes the tree.
                if (tree == null && parent == null) {
                    tree = new Tree<>(node);
                }
                traverseThroughCanopies(best, tree, node, fieldsOfInterest, hashFunctionsInCurrentNodePath, cblock);
            } else {
                node.clearBeforeSaving();
            }
        } else {
            if ((node.getDupeN() == null) || (node.getDupeN().isEmpty())) {
                LOG.warn("Ran out of training at size " + size + " for node " + node);
            } else {
                if (tree == null) {
                    throw new ZinggClientException("Unable to create Zingg models due to insufficient data. Please run Zingg after adding more data");
                }
            }
            node.clearBeforeSaving();
        }
        return tree;
    }

    /**
     * Applies the chosen canopy to {@code node}, records its hash function on the
     * current path, recurses into each child canopy, then removes the record on
     * the way back up (classic DFS push/pop).
     */
    private void traverseThroughCanopies(Canopy<R> best, Tree<Canopy<R>> tree, Canopy<R> node,
            List<FieldDefinition> fieldsOfInterest, Set<String> hashFunctionsInCurrentNodePath,
            Block<D, R, C, T> cblock) throws ZinggClientException, Exception {
        // Compute the key ONCE so add/remove stay symmetric even if `best` is
        // touched during the traversal — recomputing it at removal time could
        // leave a stale entry in the path set.
        final String pathKey = hashKeyOf(best);
        hashFunctionsInCurrentNodePath.add(pathKey);
        best.copyTo(node);
        List<Canopy<R>> canopies = node.getCanopies();
        for (Canopy<R> n : canopies) {
            node.clearBeforeSaving();
            tree.addLeaf(node, n);
            dfsToGetBlockingTree(tree, node, n, fieldsOfInterest, hashFunctionsInCurrentNodePath, cblock);
        }
        hashFunctionsInCurrentNodePath.remove(pathKey);
    }

    /**
     * Path-set key for a canopy. The "functionName:fieldName" format MUST match
     * the lookup key built in Block#isFunctionUsed.
     */
    private String hashKeyOf(Canopy<R> canopy) {
        return canopy.getFunction().getName() + ":" + canopy.getContext().fieldName;
    }
}

common/core/src/main/java/zingg/common/core/util/BlockingTreeUtil.java

+7-4
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
import zingg.common.core.block.Block;
1818
import zingg.common.core.block.Canopy;
1919
import zingg.common.core.block.Tree;
20+
import zingg.common.core.block.blockingTree.IBlockingTreeBuilder;
2021
import zingg.common.core.hash.HashFunction;
2122

2223
public abstract class BlockingTreeUtil<S, D,R,C,T> {
2324

2425
public final Log LOG = LogFactory.getLog(BlockingTreeUtil.class);
25-
26+
private final IBlockingTreeBuilder<D, R, C, T> blockingTreeBuilder;
27+
28+
public BlockingTreeUtil(IBlockingTreeBuilder<D, R, C, T> blockingTreeBuilder) {
29+
this.blockingTreeBuilder = blockingTreeBuilder;
30+
}
2631

2732
private PipeUtilBase<S, D, R, C> pipeUtil;
2833

@@ -67,9 +72,7 @@ public Tree<Canopy<R>> createBlockingTree(ZFrame<D,R,C> testData,
6772
fd.add(def);
6873
}
6974
}
70-
71-
Tree<Canopy<R>> blockingTree = cblock.getBlockingTree(null, null, root,
72-
fd);
75+
Tree<Canopy<R>> blockingTree = blockingTreeBuilder.getBlockingTree(null, null, root, fd, cblock);
7376
if (LOG.isDebugEnabled()) {
7477
LOG.debug("The blocking tree is ");
7578
blockingTree.print(2);

examples/febrl/config.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
],
9090
"labelDataSampleSize" : 0.5,
9191
"numPartitions":4,
92-
"modelId": 100,
92+
"modelId": "oct_22",
9393
"zinggDir": "models"
9494

9595
}

pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@
9393
</repositories>
9494

9595
<dependencies>
96+
<dependency>
97+
<groupId>com.opencsv</groupId>
98+
<artifactId>opencsv</artifactId>
99+
<version>5.9</version>
100+
</dependency>
96101
<dependency>
97102
<groupId>org.junit.jupiter</groupId>
98103
<artifactId>junit-jupiter-engine</artifactId>
@@ -178,6 +183,7 @@
178183
<artifactId>maven-surefire-plugin</artifactId>
179184
<version>3.2.2</version>
180185
<configuration>
186+
<argLine>-Xmx9216m</argLine>
181187
<statelessTestsetReporter implementation="org.apache.maven.plugin.surefire.extensions.junit5.JUnit5Xml30StatelessReporter">
182188
<disable>false</disable>
183189
<version>3.0</version>

spark/client/src/main/java/zingg/spark/client/SparkClient.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public SparkSession getSession() {
8181
.builder()
8282
.appName("Zingg")
8383
.getOrCreate();
84-
JavaSparkContext ctx = JavaSparkContext.fromSparkContext(session.sparkContext());
84+
JavaSparkContext ctx = JavaSparkContext.fromSparkContext(s.sparkContext());
8585
JavaSparkContext.jarOfClass(IZingg.class);
8686
LOG.debug("Context " + ctx.toString());
8787
//initHashFns();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package zingg.spark.core.block.blockingTree;
2+
3+
import org.apache.spark.sql.Column;
4+
import org.apache.spark.sql.Dataset;
5+
import org.apache.spark.sql.Row;
6+
import org.apache.spark.sql.types.DataType;
7+
import zingg.common.core.block.blockingTree.DefaultBockingTreeBuilder;
8+
import zingg.common.core.feature.FeatureFactory;
9+
import zingg.spark.core.feature.SparkFeatureFactory;
10+
11+
public class SparkDefaultBlockingTreeBuilder extends DefaultBockingTreeBuilder<Dataset<Row>, Row, Column, DataType> {
12+
@Override
13+
public FeatureFactory<DataType> getFeatureFactory() {
14+
return new SparkFeatureFactory();
15+
}
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package zingg.spark.core.block.blockingTree;
2+
3+
import org.apache.spark.sql.Column;
4+
import org.apache.spark.sql.Dataset;
5+
import org.apache.spark.sql.Row;
6+
import org.apache.spark.sql.types.DataType;
7+
import zingg.common.core.block.blockingTree.OptimizedBlockingTreeBuilder;
8+
import zingg.common.core.feature.FeatureFactory;
9+
import zingg.spark.core.feature.SparkFeatureFactory;
10+
11+
public class SparkOptimizedBlockingTreeBuilder extends OptimizedBlockingTreeBuilder<Dataset<Row>, Row, Column, DataType> {
12+
@Override
13+
public FeatureFactory<DataType> getFeatureFactory() {
14+
return new SparkFeatureFactory();
15+
}
16+
17+
}

spark/core/src/main/java/zingg/spark/core/context/ZinggSparkContext.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import zingg.common.core.util.ModelUtil;
2222
import zingg.spark.client.util.SparkDSUtil;
2323
import zingg.spark.client.util.SparkPipeUtil;
24+
import zingg.spark.core.block.blockingTree.SparkDefaultBlockingTreeBuilder;
25+
import zingg.spark.core.block.blockingTree.SparkOptimizedBlockingTreeBuilder;
2426
import zingg.spark.core.util.SparkBlockingTreeUtil;
2527
import zingg.spark.core.util.SparkGraphUtil;
2628
import zingg.spark.core.util.SparkHashUtil;
@@ -68,7 +70,7 @@ public void setUtils() {
6870
setHashUtil(new SparkHashUtil(session));
6971
setGraphUtil(new SparkGraphUtil());
7072
setModelUtil(new SparkModelUtil(session));
71-
setBlockingTreeUtil(new SparkBlockingTreeUtil(session, getPipeUtil()));
73+
setBlockingTreeUtil(new SparkBlockingTreeUtil(session, getPipeUtil(), new SparkOptimizedBlockingTreeBuilder()));
7274
}
7375

7476

spark/core/src/main/java/zingg/spark/core/util/SparkBlockingTreeUtil.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import zingg.common.core.block.Block;
2525
import zingg.common.core.block.Canopy;
2626
import zingg.common.core.block.Tree;
27+
import zingg.common.core.block.blockingTree.IBlockingTreeBuilder;
2728
import zingg.common.core.hash.HashFunction;
2829
import zingg.common.core.util.BlockingTreeUtil;
2930
import zingg.spark.client.SparkFrame;
@@ -36,7 +37,8 @@ public class SparkBlockingTreeUtil extends BlockingTreeUtil<SparkSession, Datase
3637
public static final Log LOG = LogFactory.getLog(SparkBlockingTreeUtil.class);
3738
protected SparkSession spark;
3839

39-
public SparkBlockingTreeUtil(SparkSession s, PipeUtilBase pipeUtil) {
40+
public SparkBlockingTreeUtil(SparkSession s, PipeUtilBase pipeUtil, IBlockingTreeBuilder<Dataset<Row>, Row, Column, DataType> iBlockingTreeBuilder) {
41+
super(iBlockingTreeBuilder);
4042
this.spark = s;
4143
setPipeUtil(pipeUtil);
4244
}

spark/core/src/test/java/zingg/common/core/block/TestSparkBlock.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import zingg.common.client.util.IWithSession;
1212
import zingg.common.client.util.WithSession;
1313
import zingg.spark.client.util.SparkDFObjectUtil;
14+
import zingg.spark.core.block.blockingTree.SparkOptimizedBlockingTreeBuilder;
1415
import zingg.spark.core.context.ZinggSparkContext;
1516
import zingg.spark.core.util.SparkBlockingTreeUtil;
1617
import zingg.spark.core.util.SparkHashUtil;
@@ -22,7 +23,7 @@ public class TestSparkBlock extends TestBlockBase<SparkSession, Dataset<Row>, Ro
2223
public static IWithSession<SparkSession> iWithSession = new WithSession<SparkSession>();
2324

2425
public TestSparkBlock(SparkSession sparkSession) throws ZinggClientException {
25-
super(new SparkDFObjectUtil(iWithSession), new SparkHashUtil(sparkSession), new SparkBlockingTreeUtil(sparkSession, zsCTX.getPipeUtil()));
26+
super(new SparkDFObjectUtil(iWithSession), new SparkHashUtil(sparkSession), new SparkBlockingTreeUtil(sparkSession, zsCTX.getPipeUtil(), new SparkOptimizedBlockingTreeBuilder()));
2627
iWithSession.setSession(sparkSession);
2728
zsCTX.init(sparkSession);
2829
}

0 commit comments

Comments
 (0)