2
2
3
3
import org .junit .jupiter .api .Assertions ;
4
4
import org .junit .jupiter .api .Test ;
5
+ import org .mockito .MockedStatic ;
6
+ import org .mockito .Mockito ;
5
7
import zingg .common .client .Arguments ;
6
8
import zingg .common .client .ArgumentsUtil ;
9
+ import zingg .common .client .FieldDefinition ;
7
10
import zingg .common .client .IArguments ;
11
+ import zingg .common .client .MatchType ;
8
12
import zingg .common .client .ZFrame ;
9
13
import zingg .common .client .ZinggClientException ;
10
14
import zingg .common .client .util .DFObjectUtil ;
11
- import zingg .common .core . block . data . CsvReader ;
15
+ import zingg .common .client . util . ListMap ;
12
16
import zingg .common .core .block .data .DataUtility ;
13
17
import zingg .common .core .block .model .Customer ;
14
18
import zingg .common .core .block .model .CustomerDupe ;
19
+ import zingg .common .core .hash .HashFunction ;
15
20
import zingg .common .core .util .BlockingTreeUtil ;
21
+ import zingg .common .core .util .CsvReader ;
16
22
import zingg .common .core .util .HashUtil ;
23
+ import zingg .common .core .util .Heuristics ;
17
24
25
+ import java .util .ArrayList ;
18
26
import java .util .Iterator ;
19
27
import java .util .List ;
20
28
import java .util .Objects ;
21
29
22
30
import static java .lang .Math .max ;
23
31
32
+
24
33
public abstract class TestBlockingTreeUtil <S , D , R , C , T > {
25
34
26
35
protected String TEST_DATA_BASE_LOCATION ;
@@ -61,33 +70,63 @@ public void testSameBlockingTreeWithVariance() throws Exception, ZinggClientExce
61
70
testSameBlockingTree (zFrameTest , zFramePositives );
62
71
}
63
72
73
+
64
74
public void testSameBlockingTree (ZFrame <D , R , C > zFrameTest , ZFrame <D , R , C > zFramePositives ) throws Exception , ZinggClientException {
65
75
setTestDataBaseLocation ();
66
- BlockingTreeUtil <S , D , R , C , T > blockingTreeUtil = getBlockingTreeUtil ();
67
76
HashUtil <S , D , R , C , T > hashUtil = getHashUtil ();
68
77
69
-
70
78
IArguments args = new ArgumentsUtil (Arguments .class ).createArgumentsFromJSON (
71
79
TEST_DATA_BASE_LOCATION + "/" + CONFIG_FILE ,
72
80
"" );
73
81
args .setBlockSize (8 );
74
82
75
- long ts = System .currentTimeMillis ();
76
- Tree <Canopy <R >> blockingTreeOptimized = blockingTreeUtil .createBlockingTree (zFrameTest , zFramePositives , 1 , -1 ,
77
- args , hashUtil .getHashFunctionList (), HashUtility .CACHED );
78
- System .out .println ("************ time taken to create optimized blocking tree ************ " + (System .currentTimeMillis () - ts ));
79
-
80
- ts = System .currentTimeMillis ();
81
- Tree <Canopy <R >> blockingTreeDefault = blockingTreeUtil .createBlockingTree (zFrameTest , zFramePositives , 1 , -1 ,
82
- args , hashUtil .getHashFunctionList (), HashUtility .DEFAULT );
83
- System .out .println ("************ time taken to create blocking tree ************ " + (System .currentTimeMillis () - ts ));
83
+ Tree <Canopy <R >> blockingTreeOptimized = getBlockingTree (zFrameTest , zFramePositives , hashUtil , args , "cached" );
84
+ Tree <Canopy <R >> blockingTreeDefault = getBlockingTree (zFrameTest , zFramePositives , hashUtil , args , "default" );
84
85
85
86
int depth = 1 ;
86
87
//assert both the trees are equal
87
88
Assertions .assertTrue (dfsSameTreeValidation (blockingTreeDefault , blockingTreeOptimized , depth ));
88
89
89
90
System .out .println ("-------- max depth of trees -------- " + maxDepth );
90
- System .out .println ("-------- total nodes in a trees -------- " + totalNodes );
91
+ System .out .println ("-------- total nodes in a trees ---- " + totalNodes );
92
+ }
93
+
94
+
95
+ private Tree <Canopy <R >> getBlockingTree (ZFrame <D , R , C > zFrameTest , ZFrame <D , R , C > zFramePositives , HashUtil <S , D , R , C , T > hashUtil ,
96
+ IArguments args , String blockingTreeType ) throws Exception , ZinggClientException {
97
+ long ts = System .currentTimeMillis ();
98
+ Block <D , R , C , T > block ;
99
+ if ("cached" .equals (blockingTreeType )) {
100
+ block = getCachedBasedBlock (zFrameTest , zFramePositives , hashUtil , args );
101
+ } else {
102
+ block = getDefaultBlock (zFrameTest , zFramePositives , hashUtil , args );
103
+ }
104
+ Canopy <R > root = getCanopy (zFrameTest , zFramePositives , 1 );
105
+ Tree <Canopy <R >> blockingTree = block .getBlockingTree (null , null , root , getFieldDefinitions (args ));
106
+ System .out .println ("************ time taken to create " + blockingTreeType + " blocking tree ************, " + (System .currentTimeMillis () - ts ));
107
+ return blockingTree ;
108
+ }
109
+
110
+ //Override with new CacheBasedHashFunctionUtility<D, R, C, T>()
111
+ private Block <D , R , C , T > getCachedBasedBlock (ZFrame <D , R , C > zFrameTest , ZFrame <D , R , C > zFramePositives ,
112
+ HashUtil <S , D , R , C , T > hashUtil , IArguments arguments ) throws Exception {
113
+ try (MockedStatic <HashFunctionUtilityFactory > hashFunctionUtilityFactoryMock = Mockito .mockStatic (HashFunctionUtilityFactory .class )) {
114
+ hashFunctionUtilityFactoryMock .when (() -> HashFunctionUtilityFactory .getHashFunctionUtility (Mockito .any (HashUtility .class )))
115
+ .thenReturn (new CacheBasedHashFunctionUtility <D , R , C , T >());
116
+ return getBlock (zFrameTest , 1 , zFramePositives , -1 ,
117
+ hashUtil .getHashFunctionList (), arguments );
118
+ }
119
+ }
120
+
121
+ //Override with new DefaultHashFunctionUtility<>()
122
+ private Block <D , R , C , T > getDefaultBlock (ZFrame <D , R , C > zFrameTest , ZFrame <D , R , C > zFramePositives ,
123
+ HashUtil <S , D , R , C , T > hashUtil , IArguments arguments ) throws Exception {
124
+ try (MockedStatic <HashFunctionUtilityFactory > hashFunctionUtilityFactoryMock = Mockito .mockStatic (HashFunctionUtilityFactory .class )) {
125
+ hashFunctionUtilityFactoryMock .when (() -> HashFunctionUtilityFactory .getHashFunctionUtility (Mockito .any (HashUtility .class )))
126
+ .thenReturn (new DefaultHashFunctionUtility <D , R , C , T >());
127
+ return getBlock (zFrameTest , 1 , zFramePositives , -1 ,
128
+ hashUtil .getHashFunctionList (), arguments );
129
+ }
91
130
}
92
131
93
132
@@ -159,9 +198,36 @@ private boolean isNodeSubTreesSizeEqual(Tree<Canopy<R>> node1, Tree<Canopy<R>> n
159
198
return node1 .getSubTrees ().size () == node2 .getSubTrees ().size ();
160
199
}
161
200
201
+ private Block <D , R , C , T > getBlock (ZFrame <D , R , C > testData , double sampleFraction , ZFrame <D ,R ,C > positives ,
202
+ long blockSize , ListMap <T , HashFunction <D ,R ,C ,T >> hashFunctions , IArguments args ) {
203
+ ZFrame <D ,R ,C > sample = testData .sample (false , sampleFraction );
204
+ long totalCount = sample .count ();
205
+ if (blockSize == -1 ) blockSize = Heuristics .getMaxBlockSize (totalCount , args .getBlockSize ());
206
+ positives = positives .coalesce (1 );
207
+ Block <D ,R ,C ,T > cblock = getBlock (sample , positives , hashFunctions , blockSize );
208
+ return cblock ;
209
+ }
210
+
211
+ private Canopy <R > getCanopy (ZFrame <D ,R ,C > testData , ZFrame <D ,R ,C > positives , double sampleFraction ) {
212
+ ZFrame <D ,R ,C > sample = testData .sample (false , sampleFraction );
213
+ return new Canopy <R >(sample .collectAsList (), positives .collectAsList ());
214
+ }
215
+
216
+ private List <FieldDefinition > getFieldDefinitions (IArguments arguments ) {
217
+ List <FieldDefinition > fieldDefinitions = new ArrayList <FieldDefinition >();
218
+
219
+ for (FieldDefinition def : arguments .getFieldDefinition ()) {
220
+ if (! (def .getMatchType () == null || def .getMatchType ().contains (MatchType .DONT_USE ))) {
221
+ fieldDefinitions .add (def );
222
+ }
223
+ }
224
+ return fieldDefinitions ;
225
+ }
162
226
163
227
/** Supplies the DFObjectUtil for the concrete test; no call is visible in this excerpt. */
protected abstract DFObjectUtil <S , D , R , C > getDFObjectUtil ();
164
228
/**
 * Supplies the BlockingTreeUtil for the concrete test.
 * NOTE(review): no remaining caller is visible in this excerpt - confirm it is still needed.
 */
protected abstract BlockingTreeUtil <S , D , R , C , T > getBlockingTreeUtil ();
165
229
/** Supplies the HashUtil whose getHashFunctionList() result is passed into block creation. */
protected abstract HashUtil <S , D , R , C , T > getHashUtil ();
166
230
/**
 * Called at the start of testSameBlockingTree, before TEST_DATA_BASE_LOCATION is read;
 * presumably assigns that field - confirm in the implementations.
 */
protected abstract void setTestDataBaseLocation ();
231
/**
 * Creates the Block for the given sample, positives, hash functions and block size.
 * Invoked by the private getBlock overload after it has sampled the data and
 * resolved the block size.
 */
+ protected abstract Block <D , R , C , T > getBlock (ZFrame <D ,R ,C > sample , ZFrame <D ,R ,C > positives ,
232
+ ListMap <T , HashFunction <D ,R ,C ,T >>hashFunctions , long blockSize );
167
233
}
0 commit comments