From 05c19fcc350cf5fc3ecafd12443dc352c037d24e Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 13 Oct 2024 14:11:15 +0300 Subject: [PATCH] Refactored the evaluation of NormContinuous elements --- .../jpmml/evaluator/NormalizationUtil.java | 114 ++++++++--------- .../evaluator/NormalizationUtilTest.java | 120 +++++++++++++++--- 2 files changed, 161 insertions(+), 73 deletions(-) diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java index cd85a030..f825ce32 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java @@ -20,6 +20,7 @@ import java.util.List; +import com.google.common.base.Function; import org.dmg.pmml.LinearNorm; import org.dmg.pmml.NormContinuous; import org.dmg.pmml.OutlierTreatmentMethod; @@ -54,29 +55,25 @@ public Number normalize(NormContinuous normContinuous, Number value){ public Value normalize(NormContinuous normContinuous, Value value){ List linearNorms = ensureLinearNorms(normContinuous); - LinearNorm start = linearNorms.get(0); - LinearNorm end = linearNorms.get(linearNorms.size() - 1); + LinearNorm start; + LinearNorm end; - Number startOrig = start.requireOrig(); - Number endOrig = end.requireOrig(); - - if(value.compareTo(startOrig) < 0 || value.compareTo(endOrig) > 0){ + int index = search(linearNorms, LinearNorm::requireOrig, value); + if(index < 0 || index == (linearNorms.size() - 1)){ OutlierTreatmentMethod outlierTreatmentMethod = normContinuous.getOutliers(); switch(outlierTreatmentMethod){ case AS_IS: // "Extrapolate from the first interval" - if(value.compareTo(startOrig) < 0){ + if(index < 0){ + start = linearNorms.get(0); end = linearNorms.get(1); - - endOrig = end.requireOrig(); } else // "Extrapolate from the last interval" { start = linearNorms.get(linearNorms.size() - 2); - - startOrig = start.requireOrig(); + end = linearNorms.get(linearNorms.size() - 1); } break; case AS_MISSING_VALUES: @@ -84,17 +81,17 @@ public Value normalize(NormContinuous normContinuous, Valu return null; case AS_EXTREME_VALUES: // "Map to the value of the first interval" - if(value.compareTo(startOrig) < 0){ - Number startNorm = start.requireNorm(); + if(index < 0){ + start = linearNorms.get(0); - return value.reset(startNorm); + return value.reset(start.requireNorm()); } else // "Map to the value of the last interval" { - Number endNorm = end.requireNorm(); + end = linearNorms.get(linearNorms.size() - 1); - return value.reset(endNorm); + return value.reset(end.requireNorm()); } default: throw new UnsupportedAttributeException(normContinuous, outlierTreatmentMethod); @@ -102,31 +99,11 @@ public Value normalize(NormContinuous normContinuous, Valu } else { - for(int i = 1, max = (linearNorms.size() - 1); i < max; i++){ - LinearNorm linearNorm = linearNorms.get(i); - - Number orig = linearNorm.requireOrig(); - - if(value.compareTo(orig) >= 0){ - start = linearNorm; - - startOrig = orig; - } else - - if(value.compareTo(orig) <= 0){ - end = linearNorm; - - endOrig = orig; - - break; - } - } + start = linearNorms.get(index); + end = linearNorms.get(index + 1); } - Number startNorm = start.requireNorm(); - Number endNorm = end.requireNorm(); - - return value.normalize(startOrig, startNorm, endOrig, endNorm); + return value.normalize(start.requireOrig(), start.requireNorm(), end.requireOrig(), end.requireNorm()); } static @@ -142,36 +119,59 @@ public Number denormalize(NormContinuous normContinuous, Number value){ public Value denormalize(NormContinuous normContinuous, Value value){ List linearNorms = ensureLinearNorms(normContinuous); - LinearNorm start = linearNorms.get(0); - LinearNorm end = linearNorms.get(linearNorms.size() - 1); + LinearNorm start; + LinearNorm end; + + int index = search(linearNorms, LinearNorm::requireNorm, value); + if(index < 0 || index == (linearNorms.size() - 1)){ + throw new NotImplementedException(); + } else + + { + start = linearNorms.get(index); + end = linearNorms.get(index + 1); + } - Number startNorm = start.requireNorm(); - Number endNorm = end.requireNorm(); + return value.denormalize(start.requireOrig(), start.requireNorm(), end.requireOrig(), end.requireNorm()); + } + + static + int search(List linearNorms, Function thresholdFunction, Value value){ - for(int i = 1, max = (linearNorms.size() - 1); i < max; i++){ + for(int i = 0, max = linearNorms.size(); i < max; i++){ LinearNorm linearNorm = linearNorms.get(i); - Number norm = linearNorm.requireNorm(); + Number threshold = thresholdFunction.apply(linearNorm); - if(value.compareTo(norm) >= 0){ - start = linearNorm; + if(value.compareTo(threshold) >= 0){ - startNorm = norm; - } else + if(i < (max - 1)){ + LinearNorm nextLinearNorm = linearNorms.get(i + 1); - if(value.compareTo(norm) <= 0){ - end = linearNorm; + Number nextThreshold = thresholdFunction.apply(nextLinearNorm); - endNorm = norm; + // Assume a closed-closed range, rather than a closed-open range. + // If the value matches some threshold value exactly, + // then it does not matter which bin (ie. this or the next) is used for interpolation. + if(value.compareTo(nextThreshold) <= 0){ + return i; + } + + continue; + } else - break; + // The last element + { + return i; + } + } else + + { + return -1; } } - Number startOrig = start.requireOrig(); - Number endOrig = end.requireOrig(); - - return value.denormalize(startOrig, startNorm, endOrig, endNorm); + throw new IllegalArgumentException(); } static diff --git a/pmml-evaluator/src/test/java/org/jpmml/evaluator/NormalizationUtilTest.java b/pmml-evaluator/src/test/java/org/jpmml/evaluator/NormalizationUtilTest.java index 61602bce..8bff0567 100644 --- a/pmml-evaluator/src/test/java/org/jpmml/evaluator/NormalizationUtilTest.java +++ b/pmml-evaluator/src/test/java/org/jpmml/evaluator/NormalizationUtilTest.java @@ -18,6 +18,9 @@ */ package org.jpmml.evaluator; +import java.util.ArrayList; +import java.util.List; + import org.dmg.pmml.LinearNorm; import org.dmg.pmml.NormContinuous; import org.dmg.pmml.OutlierTreatmentMethod; @@ -25,6 +28,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; public class NormalizationUtilTest implements Deltas { @@ -32,40 +36,119 @@ public class NormalizationUtilTest implements Deltas { public void normalize(){ NormContinuous normContinuous = createNormContinuous(); - assertEquals(BEGIN[1], (double)NormalizationUtil.normalize(normContinuous, BEGIN[0]), DOUBLE_EXACT); - assertEquals(interpolate(1.212d, BEGIN, MIDPOINT), (double)NormalizationUtil.normalize(normContinuous, 1.212d), DOUBLE_EXACT); - assertEquals(MIDPOINT[1], (double)NormalizationUtil.normalize(normContinuous, MIDPOINT[0]), DOUBLE_EXACT); - assertEquals(interpolate(6.5d, MIDPOINT, END), (double)NormalizationUtil.normalize(normContinuous, 6.5d), DOUBLE_EXACT); - assertEquals(END[1], (double)NormalizationUtil.normalize(normContinuous, END[0]), DOUBLE_EXACT); + assertEquals(BEGIN[1], normalize(normContinuous, BEGIN[0]), DOUBLE_EXACT); + assertEquals(interpolate(1.212d, BEGIN, MIDPOINT), normalize(normContinuous, 1.212d), DOUBLE_EXACT); + assertEquals(MIDPOINT[1], normalize(normContinuous, MIDPOINT[0]), DOUBLE_EXACT); + assertEquals(interpolate(6.5d, MIDPOINT, END), normalize(normContinuous, 6.5d), DOUBLE_EXACT); + assertEquals(END[1], normalize(normContinuous, END[0]), DOUBLE_EXACT); } @Test public void normalizeOutliers(){ NormContinuous normContinuous = createNormContinuous(); - assertEquals(interpolate(-1d, BEGIN, MIDPOINT), (double)NormalizationUtil.normalize(normContinuous, -1d), DOUBLE_EXACT); - assertEquals(interpolate(12.2d, MIDPOINT, END), (double)NormalizationUtil.normalize(normContinuous, 12.2d), DOUBLE_EXACT); + assertEquals(interpolate(-1d, BEGIN, MIDPOINT), normalize(normContinuous, -1d), DOUBLE_EXACT); + assertEquals(interpolate(12.2d, MIDPOINT, END), normalize(normContinuous, 12.2d), DOUBLE_EXACT); normContinuous.setOutliers(OutlierTreatmentMethod.AS_MISSING_VALUES); - assertNull(NormalizationUtil.normalize(normContinuous, -1d)); - assertNull(NormalizationUtil.normalize(normContinuous, 12.2d)); + assertNull(normalize(normContinuous, -1d)); + assertNull(normalize(normContinuous, 12.2d)); normContinuous.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES); - assertEquals(BEGIN[1], (double)NormalizationUtil.normalize(normContinuous, -1d), DOUBLE_EXACT); - assertEquals(END[1], (double)NormalizationUtil.normalize(normContinuous, 12.2d), DOUBLE_EXACT); + assertEquals(BEGIN[1], normalize(normContinuous, -1d), DOUBLE_EXACT); + assertEquals(END[1], normalize(normContinuous, 12.2d), DOUBLE_EXACT); } @Test public void denormalize(){ NormContinuous normContinuous = createNormContinuous(); - assertEquals(BEGIN[0], (double)NormalizationUtil.denormalize(normContinuous, BEGIN[1]), DOUBLE_EXACT); - assertEquals(0.3d, (double)NormalizationUtil.denormalize(normContinuous, interpolate(0.3d, BEGIN, MIDPOINT)), DOUBLE_EXACT); - assertEquals(MIDPOINT[0], (double)NormalizationUtil.denormalize(normContinuous, MIDPOINT[1]), DOUBLE_EXACT); - assertEquals(7.123d, (double)NormalizationUtil.denormalize(normContinuous, interpolate(7.123d, MIDPOINT, END)), DOUBLE_EXACT); - assertEquals(END[0], (double)NormalizationUtil.denormalize(normContinuous, END[1]), DOUBLE_EXACT); + try { + denormalize(normContinuous, -0.5d); + + fail(); + } catch(NotImplementedException nie){ + // Ignored + } + + assertEquals(BEGIN[0], denormalize(normContinuous, BEGIN[1]), DOUBLE_EXACT); + assertEquals(0.3d, denormalize(normContinuous, interpolate(0.3d, BEGIN, MIDPOINT)), DOUBLE_EXACT); + assertEquals(MIDPOINT[0], denormalize(normContinuous, MIDPOINT[1]), DOUBLE_EXACT); + assertEquals(7.123d, denormalize(normContinuous, interpolate(7.123d, MIDPOINT, END)), DOUBLE_EXACT); + assertEquals(END[0], denormalize(normContinuous, END[1]), DOUBLE_EXACT); + + try { + denormalize(normContinuous, 1.5d); + + fail(); + } catch(NotImplementedException nie){ + // Ignored + } + } + + @Test + public void standardize(){ + double mu = 1.5; + double stdev = Math.sqrt(2d); + + NormContinuous normContinuous = new NormContinuous("x", null) + .setOutliers(OutlierTreatmentMethod.AS_IS) + .addLinearNorms( + new LinearNorm(0d, -(mu / stdev)), + new LinearNorm(mu, 0d) + ); + + assertEquals(zScore(-2d, mu, stdev), normalize(normContinuous, -2d), DOUBLE_EXACT); + assertEquals(zScore(-1d, mu, stdev), normalize(normContinuous, -1d), DOUBLE_EXACT); + assertEquals(zScore(0d, mu, stdev), normalize(normContinuous, 0d), DOUBLE_EXACT); + assertEquals(zScore(1d, mu, stdev), normalize(normContinuous, 1d), DOUBLE_EXACT); + assertEquals(zScore(2d, mu, stdev), normalize(normContinuous, 2d), DOUBLE_EXACT); + + assertEquals(1d, denormalize(normContinuous, zScore(1d, mu, stdev)), DOUBLE_EXACT); + } + + @Test + public void search(){ + List linearNorms = new ArrayList<>(); + + linearNorms.add(new LinearNorm(0d, null)); + linearNorms.add(new LinearNorm(1d, null)); + + assertEquals(-1, search(linearNorms, -1d)); + assertEquals(0, search(linearNorms, 0d)); + assertEquals(0, search(linearNorms, 1d)); + assertEquals(1, search(linearNorms, 2d)); + + linearNorms.add(new LinearNorm(2d, null)); + + assertEquals(-1, search(linearNorms, -1d)); + assertEquals(0, search(linearNorms, 1d)); + assertEquals(1, search(linearNorms, 2d)); + assertEquals(2, search(linearNorms, 3d)); + + linearNorms.add(new LinearNorm(3d, null)); + + assertEquals(-1, search(linearNorms,-1d)); + assertEquals(1, search(linearNorms, 2d)); + assertEquals(2, search(linearNorms, 3d)); + assertEquals(3, search(linearNorms, 4d)); + } + + static + private Double normalize(NormContinuous normContinuous, double value){ + return (Double)NormalizationUtil.normalize(normContinuous, value); + } + + static + private Double denormalize(NormContinuous normContinuous, double value){ + return (Double)NormalizationUtil.denormalize(normContinuous, value); + } + + static + private int search(List linearNorms, double value){ + return NormalizationUtil.search(linearNorms, LinearNorm::requireOrig, new DoubleValue(value)); } static @@ -73,6 +156,11 @@ private double interpolate(double x, double[] begin, double[] end){ return begin[1] + (x - begin[0]) / (end[0] - begin[0]) * (end[1] - begin[1]); } + static + private double zScore(double x, double mu, double stdev){ + return (x - mu) / stdev; + } + static private NormContinuous createNormContinuous(){ NormContinuous result = new NormContinuous("x", null)