Skip to content

Commit 64ad943

Browse files
authored
mapping fix for labeller
* added cache * changes * added select to make original order * added cache * Case normalize (zinggAI#1027) * added Case normalizer preprocessor * removed toLowerCase() in sim call() * fixed junits * added junits for case normalizer * added for spark driver memory in spark session builder * added log * added log * logged memory in GB * abstracted out stopWord files names * added logging * added exception * made join on both id and cluster * test not needed as we are already case normalizing at start
1 parent e04cb3c commit 64ad943

File tree

10 files changed

+22
-15
lines changed

10 files changed

+22
-15
lines changed

common/client/src/main/java/zingg/common/client/util/DSUtil.java

+13-2
Original file line numberDiff line numberDiff line change
@@ -271,17 +271,28 @@ public ZFrame<D,R,C> postprocess(ZFrame<D,R,C> actual, ZFrame<D,R,C> orig) {
271271
public ZFrame<D,R,C> postProcessLabel(ZFrame<D,R,C> updatedLabelledRecords, ZFrame<D,R,C> unmarkedRecords) {
272272
List<C> cols = new ArrayList<C>();
273273
cols.add(updatedLabelledRecords.col(ColName.ID_COL));
274+
cols.add(updatedLabelledRecords.col(ColName.CLUSTER_COLUMN));
274275

275276
String[] unmarkedRecordColumns = unmarkedRecords.columns();
276277

277278
//drop isMatch column from unMarked records
278279
//and replace with updated isMatch column
279280
cols.add(updatedLabelledRecords.col(ColName.MATCH_FLAG_COL));
280-
ZFrame<D,R,C> zFieldsFromUpdatedLabelledRecords = updatedLabelledRecords.select(cols);
281+
ZFrame<D,R,C> zFieldsFromUpdatedLabelledRecords = updatedLabelledRecords.select(cols).
282+
withColumnRenamed(ColName.ID_COL, ColName.COL_PREFIX + ColName.ID_COL).
283+
withColumnRenamed(ColName.CLUSTER_COLUMN, ColName.COL_PREFIX + ColName.CLUSTER_COLUMN);
284+
281285
unmarkedRecords = unmarkedRecords.drop(ColName.MATCH_FLAG_COL);
282286

287+
/*
288+
join on z_id and z_cluster
289+
*/
290+
C joinCondition1 = unmarkedRecords.equalTo(unmarkedRecords.col(ColName.ID_COL), zFieldsFromUpdatedLabelledRecords.col(ColName.COL_PREFIX + ColName.ID_COL));
291+
C joinCondition2 = unmarkedRecords.equalTo(unmarkedRecords.col(ColName.CLUSTER_COLUMN), zFieldsFromUpdatedLabelledRecords.col(ColName.COL_PREFIX + ColName.CLUSTER_COLUMN));
292+
C joinCondition = unmarkedRecords.and(joinCondition1, joinCondition2);
293+
283294
//we are selecting columns to bring back to original shape
284-
return unmarkedRecords.joinOnCol(zFieldsFromUpdatedLabelledRecords, ColName.ID_COL).select(unmarkedRecordColumns);
295+
return unmarkedRecords.join(zFieldsFromUpdatedLabelledRecords, joinCondition, "inner").select(unmarkedRecordColumns);
285296
}
286297

287298

common/core/src/main/java/zingg/common/core/executor/Trainer.java

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ public void execute() throws ZinggClientException {
3131
ZFrame<D,R,C> traOriginal = getDSUtil().getTraining(getPipeUtil(), args, getModelHelper());
3232
ZFrame<D,R,C> tra = preprocess(traOriginal).cache();
3333
tra = getDSUtil().joinWithItself(tra, ColName.CLUSTER_COLUMN, true);
34-
tra = tra.cache();
3534
positives = tra.filter(tra.equalTo(ColName.MATCH_FLAG_COL,ColValues.MATCH_TYPE_MATCH));
3635
negatives = tra.filter(tra.equalTo(ColName.MATCH_FLAG_COL,ColValues.MATCH_TYPE_NOT_A_MATCH));
3736

common/core/src/main/java/zingg/common/core/similarity/function/OnlyAlphabetsExactSimilarity.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public Double call(String first, String second) {
3636
if (score1 != 1.0d && score2 != 1.0d) {
3737
first = first.replaceAll("[0-9.]", "");
3838
second = second.replaceAll("[0-9.]", "");
39-
score = first.equalsIgnoreCase(second)? 1.0d : 0.0d;
39+
score = first.equals(second)? 1.0d : 0.0d;
4040
}
4141
else {
4242
score = 1.0d;

common/core/src/main/java/zingg/common/core/similarity/function/PinCodeMatchTypeFunction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public Double call(String first, String second) {
2525
if (second == null || second.trim().length() ==0) return 1d;
2626
first = first.split("-")[0];
2727
second = second.split("-")[0];
28-
double score = first.trim().equalsIgnoreCase(second.trim()) ? 1d : 0d;
28+
double score = first.trim().equals(second.trim()) ? 1d : 0d;
2929
return score;
3030
}
3131
}

common/core/src/main/java/zingg/common/core/similarity/function/StringSimilarityDistanceFunction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public AbstractStringDistance getDistanceFunction(){
2626
public Double call(String first, String second) {
2727
if (first == null || first.trim().length() ==0) return 1d;
2828
if (second == null || second.trim().length() ==0) return 1d;
29-
if (first.equalsIgnoreCase(second)) return 1d;
29+
if (first.equals(second)) return 1d;
3030
double score = getDistanceFunction().score(first, second);
3131
if (Double.isNaN(score)) return 0d;
3232
//LOG.warn(" score " + gap + " " + first + " " + second + " is " + score);

common/core/src/main/java/zingg/common/core/similarity/function/StringSimilarityFunction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public StringSimilarityFunction(String name) {
2222
public Double call(String first, String second) {
2323
if (first == null || first.trim().length() ==0) return 1d;
2424
if (second == null || second.trim().length() ==0) return 1d;
25-
double score = first.trim().equalsIgnoreCase(second.trim()) ? 1d : 0d;
25+
double score = first.trim().equals(second.trim()) ? 1d : 0d;
2626
return score;
2727
}
2828

common/core/src/test/java/zingg/common/core/similarity/function/TestOnlyAlphabetsExactSimilarity.java

-5
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,4 @@ public void testDiffNoNumber() {
3333
assertEquals(0d, sim.call("I have a no number", "I have r number"));
3434
}
3535

36-
@Test
37-
public void testSameIgnoreCase() {
38-
OnlyAlphabetsExactSimilarity sim = new OnlyAlphabetsExactSimilarity();
39-
assertEquals(1d, sim.call("I have 1 number", "I HAVE 2 number"));
40-
}
4136
}

spark/core/src/main/java/zingg/spark/core/executor/SparkMatcher.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
*
2626
*
2727
*/
28-
public class SparkMatcher extends Matcher<SparkSession,Dataset<Row>,Row,Column,DataType> implements ISparkPreprocMapSupplier{
28+
public class SparkMatcher extends Matcher<SparkSession,Dataset<Row>,Row,Column,DataType> implements ISparkPreprocMapSupplier {
2929

3030

3131
private static final long serialVersionUID = 1L;

spark/core/src/main/java/zingg/spark/core/preprocess/caseNormalize/SparkCaseNormalizer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public SparkCaseNormalizer(IContext<SparkSession, Dataset<Row>, Row, Column, Dat
3131

3232
@Override
3333
protected ZFrame<Dataset<Row>, Row, Column> applyCaseNormalizer(ZFrame<Dataset<Row>, Row, Column> incomingDataFrame, List<String> relevantFields) {
34+
String[] incomingDFColumns = incomingDataFrame.columns();
3435
Seq<String> columnsSeq = JavaConverters.asScalaIteratorConverter(relevantFields.iterator())
3536
.asScala()
3637
.toSeq();
@@ -41,6 +42,6 @@ protected ZFrame<Dataset<Row>, Row, Column> applyCaseNormalizer(ZFrame<Dataset<R
4142
Seq<Column> caseNormalizedSeq = JavaConverters.asScalaIteratorConverter(caseNormalizedValues.iterator())
4243
.asScala()
4344
.toSeq();
44-
return new SparkFrame(incomingDataFrame.df().withColumns(columnsSeq, caseNormalizedSeq));
45+
return new SparkFrame(incomingDataFrame.df().withColumns(columnsSeq, caseNormalizedSeq)).select(incomingDFColumns);
4546
}
4647
}

spark/core/src/test/java/zingg/spark/core/executor/labeller/ProgrammaticSparkLabeller.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
import zingg.common.core.executor.labeller.ProgrammaticLabeller;
1414
import zingg.spark.core.context.ZinggSparkContext;
1515
import zingg.spark.core.executor.SparkLabeller;
16+
import zingg.spark.core.preprocess.ISparkPreprocMapSupplier;
1617

17-
public class ProgrammaticSparkLabeller extends SparkLabeller {
18+
public class ProgrammaticSparkLabeller extends SparkLabeller implements ISparkPreprocMapSupplier {
1819

1920
private static final long serialVersionUID = 1L;
2021

0 commit comments

Comments
 (0)