-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #110 from g-degiorgi/feature/create_cpt
added create CPT feature with interval helper methods
- Loading branch information
Showing
3 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
125 changes: 125 additions & 0 deletions
125
src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package ch.idsia.crema.preprocess.creators; | ||
|
||
import ch.idsia.crema.core.Strides; | ||
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor; | ||
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactorFactory; | ||
import ch.idsia.crema.utility.IndexIterator; | ||
|
||
import java.util.Arrays; | ||
|
||
public class CreateCPT { | ||
|
||
/** | ||
* Border returns the borders of the intervals | ||
* | ||
* @param lower array of lower bounds for all the variables | ||
* ex: double[] lower = new double[]{1.55, 55.0}; | ||
* @param upper array of upper bounds for all the variables | ||
* ex: double[] upper = new double[]{1.90, 115.0}; | ||
*/ | ||
public double[] borders(double[] lower, double[] upper, Op operation) { | ||
|
||
//shortcut in case the function is strictly monotone growing | ||
//double[] extremes= new double[]{operation.execute(lower[0], upper[1]), operation.execute(upper[0], lower[1])}; | ||
//return Arrays.stream(extremes).sorted().toArray(); | ||
|
||
double[] results = new double[2 * lower.length]; | ||
results[0] = operation.execute(lower[0], lower[1]); | ||
results[1] = operation.execute(lower[0], upper[1]); | ||
results[2] = operation.execute(upper[0], lower[1]); | ||
results[3] = operation.execute(upper[0], upper[1]); | ||
|
||
Arrays.sort(results); | ||
double firstElement = results[0]; | ||
double lastElement = results[results.length - 1]; | ||
|
||
return new double[]{firstElement, lastElement}; | ||
} | ||
|
||
/** | ||
* Method to create a CPT | ||
* | ||
* @param childVar variable of the child | ||
* @param parentsVars array of variables of the parents | ||
* @param childCuts array of cuts for the child | ||
* @param parentCuts array of cuts for the parents | ||
* @param operation operation to be performed with the cuts | ||
* example K(bmi|w,H) | ||
* @return IntervalFactor representing the CPT of the child given the parents | ||
*/ | ||
public IntervalFactor create(int childVar, int[] parentsVars, double[] childCuts, double[][] parentCuts, Op operation) { | ||
|
||
// root nodes creation | ||
// add child node | ||
int dimChild = childCuts.length + 1; | ||
Strides stridesChild = Strides.var(childVar, dimChild); | ||
|
||
// create domain | ||
Strides dom = Strides.empty(); | ||
for (int i = 0; i < parentsVars.length; i++) { | ||
dom = dom.and(parentsVars[i], parentCuts[i].length - 1); | ||
} | ||
// add parents nodes | ||
IntervalFactorFactory factory = IntervalFactorFactory.factory().domain(stridesChild, dom); | ||
|
||
// create iterator | ||
IndexIterator iterator = dom.getIterator(); | ||
//iterate over all possible combinations | ||
while (iterator.hasNext()) { | ||
int[] comb = iterator.getPositions().clone(); | ||
// be aware that the structure returned by the method has to be compliant with the child | ||
double[] parentIntervalLower = new double[parentCuts.length]; | ||
double[] parentIntervalUpper = new double[parentCuts.length]; | ||
|
||
for (int i = 0; i < parentCuts.length; i++) { | ||
parentIntervalLower[i] = parentCuts[i][comb[i]]; | ||
parentIntervalUpper[i] = parentCuts[i][comb[i] + 1]; | ||
} | ||
|
||
double[] intervalBorders = borders(parentIntervalLower, parentIntervalUpper, operation); | ||
//map the integers of the position of the childCuts | ||
int[] intervalNumber = Arrays.stream(intervalBorders).mapToInt(val -> whichPosition(childCuts, val)).toArray(); | ||
|
||
// the lower is set to an array of zeroes the upper is set to 1 in the position of the interval number | ||
factory.set(new double[dimChild], createUpper(intervalNumber, dimChild), comb); | ||
iterator.next(); | ||
} | ||
return factory.get(); | ||
} | ||
|
||
/** | ||
* @param interval array containing the position | ||
* @param dim dimension of the array to be generated | ||
* @return double array with 1.0 set for every interval value | ||
*/ | ||
public double[] createUpper(int[] interval, int dim) { | ||
double[] upper = new double[dim]; | ||
|
||
// we set to 1.0 all the element relative to the interval | ||
for (int ind : interval) { | ||
upper[ind - 1] = 1.0; | ||
} | ||
return upper; | ||
} | ||
|
||
/** | ||
* Method to find the interval containing the specified value | ||
* Specials: if the number is lower than the first cut, it will be placed in the first interval | ||
* if the number is higher than the last cut, it will be placed in the last interval | ||
* | ||
* @param cutsX array of cuts, typically for the discretization | ||
* @param X value that we want to place | ||
* @return integer of the interval containing x | ||
*/ | ||
//greedy | ||
public int whichPosition(double[] cutsX, double X) { | ||
int position = 0; //starts from 1 | ||
for (int i = 1; i < cutsX.length - 1; i++) { | ||
if (X <= cutsX[i]) { | ||
break; | ||
} | ||
position++; | ||
} | ||
return position + 1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
package ch.idsia.crema.preprocess.creators; | ||
|
||
@FunctionalInterface | ||
public interface Op { | ||
double execute(double a, double b); | ||
} |
79 changes: 79 additions & 0 deletions
79
src/test/java/ch/idsia/crema/preprocess/creators/CreateCPTTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package ch.idsia.crema.preprocess.creators; | ||
|
||
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertArrayEquals; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
public class CreateCPTTest { | ||
|
||
private final CreateCPT creator = new CreateCPT(); | ||
|
||
@Test | ||
public void testCreate() { | ||
|
||
//BMI example | ||
double[] cutsH = new double[]{1.55, 1.60, 1.65, 1.70, 1.75, 1.80, 1.85, 1.90, 1.95, 2.00}; | ||
double[] cutsW = new double[]{55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0}; | ||
double[] cutsBMI = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0}; | ||
|
||
double[][] parents = new double[][]{cutsW, cutsH}; | ||
Op bmi = (w, h) -> w / h / h; | ||
|
||
IntervalFactor cpt = creator.create(2, new int[]{0, 1}, cutsBMI, parents, bmi); | ||
//System.out.println(cpt); | ||
|
||
// w1 and h1 | ||
assertArrayEquals(cpt.getLower(0, 0), new double[10]); | ||
assertArrayEquals(cpt.getUpper(0, 0), creator.createUpper(new int[]{4}, 10)); | ||
|
||
// w2 and h1 | ||
assertArrayEquals(cpt.getLower(1, 0), new double[10]); | ||
assertArrayEquals(cpt.getUpper(1, 0), creator.createUpper(new int[]{4, 5}, 10)); | ||
|
||
// w5 and h6 | ||
assertArrayEquals(cpt.getLower(4, 5), new double[10]); | ||
assertArrayEquals(cpt.getUpper(4, 5), creator.createUpper(new int[]{4}, 10)); | ||
|
||
// w12 and h9 | ||
assertArrayEquals(cpt.getLower(11, 8), new double[10]); | ||
assertArrayEquals(cpt.getUpper(11, 8), creator.createUpper(new int[]{5, 6}, 10)); | ||
} | ||
|
||
@Test | ||
public void testBorders() { | ||
double[] parentIntervalLower = new double[]{55.0, 1.55}; | ||
double[] parentIntervalUpper = new double[]{60.0, 1.60}; | ||
//example of the BMI | ||
Op bmi = (w, h) -> w / h / h; | ||
|
||
//bmi low = 21.48 | ||
//bmi high = 24.97 | ||
double tolerance = 0.001; | ||
double[] interval = creator.borders(parentIntervalLower, parentIntervalUpper, bmi); | ||
|
||
assertArrayEquals(new double[]{21.484, 24.973}, interval, tolerance); | ||
} | ||
|
||
@Test | ||
public void testWhichPosition() { | ||
double[] intervals = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0}; | ||
assertEquals(1, creator.whichPosition(intervals, -10)); | ||
assertEquals(1, creator.whichPosition(intervals, 11)); | ||
assertEquals(1, creator.whichPosition(intervals, 15)); | ||
assertEquals(3, creator.whichPosition(intervals, 17)); | ||
assertEquals(5, creator.whichPosition(intervals, 29)); | ||
assertEquals(8, creator.whichPosition(intervals, 41)); | ||
assertEquals(8, creator.whichPosition(intervals, 200)); | ||
} | ||
|
||
@Test | ||
public void testCreateUpper() { | ||
int[] interval = new int[]{3, 4}; | ||
double[] result = creator.createUpper(interval, 4); | ||
double[] expected = new double[]{.0, .0, 1.0, 1.0}; | ||
|
||
assertArrayEquals(result, expected); | ||
} | ||
} |