Skip to content

Commit

Permalink
Optimized the evaluation of NormContinuous elements
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Oct 13, 2024
1 parent 05c19fc commit 4673e06
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public <V extends Number> Value<V> normalize(NormContinuous normContinuous, Valu
LinearNorm start;
LinearNorm end;

int index = search(linearNorms, LinearNorm::requireOrig, value);
int index = binarySearch(linearNorms, LinearNorm::requireOrig, value);
if(index < 0 || index == (linearNorms.size() - 1)){
OutlierTreatmentMethod outlierTreatmentMethod = normContinuous.getOutliers();

Expand Down Expand Up @@ -122,7 +122,7 @@ public <V extends Number> Value<V> denormalize(NormContinuous normContinuous, Va
LinearNorm start;
LinearNorm end;

int index = search(linearNorms, LinearNorm::requireNorm, value);
int index = binarySearch(linearNorms, LinearNorm::requireNorm, value);
if(index < 0 || index == (linearNorms.size() - 1)){
throw new NotImplementedException();
} else
Expand All @@ -136,42 +136,46 @@ public <V extends Number> Value<V> denormalize(NormContinuous normContinuous, Va
}

static
<V extends Number> int search(List<LinearNorm> linearNorms, Function<LinearNorm, Number> thresholdFunction, Value<V> value){
private <V extends Number> int binarySearch(List<LinearNorm> linearNorms, Function<LinearNorm, Number> thresholdFunction, Value<V> value){
int low = 0;
int high = linearNorms.size() - 1;

for(int i = 0, max = linearNorms.size(); i < max; i++){
LinearNorm linearNorm = linearNorms.get(i);
while(low <= high){
int mid = low + (high - low) / 2;

LinearNorm linearNorm = linearNorms.get(mid);

Number threshold = thresholdFunction.apply(linearNorm);

if(value.compareTo(threshold) >= 0){

if(i < (max - 1)){
LinearNorm nextLinearNorm = linearNorms.get(i + 1);
if(mid < (linearNorms.size() - 1)){
LinearNorm nextLinearNorm = linearNorms.get(mid + 1);

Number nextThreshold = thresholdFunction.apply(nextLinearNorm);

// Assume a closed-closed range, rather than a closed-open range.
// If the value matches some threshold value exactly,
// then it does not matter which bin (ie. this or the next) is used for interpolation.
if(value.compareTo(nextThreshold) <= 0){
return i;
return mid;
}

continue;
} else

// The last element
{
return i;
return mid;
}

low = (mid + 1);
} else

{
return -1;
high = (mid - 1);
}
}

throw new IllegalArgumentException();
return -1;
}

static
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
*/
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.List;

import org.dmg.pmml.LinearNorm;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OutlierTreatmentMethod;
Expand Down Expand Up @@ -109,33 +106,6 @@ public void standardize(){
assertEquals(1d, denormalize(normContinuous, zScore(1d, mu, stdev)), DOUBLE_EXACT);
}

@Test
public void search(){
List<LinearNorm> linearNorms = new ArrayList<>();

linearNorms.add(new LinearNorm(0d, null));
linearNorms.add(new LinearNorm(1d, null));

assertEquals(-1, search(linearNorms, -1d));
assertEquals(0, search(linearNorms, 0d));
assertEquals(0, search(linearNorms, 1d));
assertEquals(1, search(linearNorms, 2d));

linearNorms.add(new LinearNorm(2d, null));

assertEquals(-1, search(linearNorms, -1d));
assertEquals(0, search(linearNorms, 1d));
assertEquals(1, search(linearNorms, 2d));
assertEquals(2, search(linearNorms, 3d));

linearNorms.add(new LinearNorm(3d, null));

assertEquals(-1, search(linearNorms,-1d));
assertEquals(1, search(linearNorms, 2d));
assertEquals(2, search(linearNorms, 3d));
assertEquals(3, search(linearNorms, 4d));
}

static
private Double normalize(NormContinuous normContinuous, double value){
return (Double)NormalizationUtil.normalize(normContinuous, value);
Expand All @@ -146,11 +116,6 @@ private Double denormalize(NormContinuous normContinuous, double value){
return (Double)NormalizationUtil.denormalize(normContinuous, value);
}

static
private int search(List<LinearNorm> linearNorms, double value){
return NormalizationUtil.search(linearNorms, LinearNorm::requireOrig, new DoubleValue(value));
}

static
private double interpolate(double x, double[] begin, double[] end){
return begin[1] + (x - begin[0]) / (end[0] - begin[0]) * (end[1] - begin[1]);
Expand Down

0 comments on commit 4673e06

Please sign in to comment.