From 9490b27bcdf3d6afa193ecc776aa378db17ec59e Mon Sep 17 00:00:00 2001 From: g-degiorgi Date: Tue, 22 Aug 2023 01:31:36 +0200 Subject: [PATCH] added create CPT feature with interval helper methods --- .../crema/preprocess/creators/CreateCPT.java | 125 ++++++++++++++++++ .../idsia/crema/preprocess/creators/Op.java | 6 + .../preprocess/creators/CreateCPTTest.java | 79 +++++++++++ 3 files changed, 210 insertions(+) create mode 100644 src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java create mode 100644 src/main/java/ch/idsia/crema/preprocess/creators/Op.java create mode 100644 src/test/java/ch/idsia/crema/preprocess/creators/CreateCPTTest.java diff --git a/src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java b/src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java new file mode 100644 index 00000000..6c0a3109 --- /dev/null +++ b/src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java @@ -0,0 +1,125 @@ +package ch.idsia.crema.preprocess.creators; + +import ch.idsia.crema.core.Strides; +import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor; +import ch.idsia.crema.factor.credal.linear.interval.IntervalFactorFactory; +import ch.idsia.crema.utility.IndexIterator; + +import java.util.Arrays; + +public class CreateCPT { + + /** + * Border returns the borders of the intervals + * + * @param lower array of lower bounds for all the variables + * ex: double[] lower = new double[]{1.55, 55.0}; + * @param upper array of upper bounds for all the variables + * ex: double[] upper = new double[]{1.90, 115.0}; + */ + public double[] borders(double[] lower, double[] upper, Op operation) { + + //shortcut in case the function is strictly monotone growing + //double[] extremes= new double[]{operation.execute(lower[0], upper[1]), operation.execute(upper[0], lower[1])}; + //return Arrays.stream(extremes).sorted().toArray(); + + double[] results = new double[2 * lower.length]; + results[0] = operation.execute(lower[0], lower[1]); + results[1] = operation.execute(lower[0], upper[1]); + results[2] = operation.execute(upper[0], lower[1]); + results[3] = operation.execute(upper[0], upper[1]); + + Arrays.sort(results); + double firstElement = results[0]; + double lastElement = results[results.length - 1]; + + return new double[]{firstElement, lastElement}; + } + + /** + * Method to create a CPT + * + * @param childVar variable of the child + * @param parentsVars array of variables of the parents + * @param childCuts array of cuts for the child + * @param parentCuts array of cuts for the parents + * @param operation operation to be performed with the cuts + * example K(bmi|w,H) + * @return IntervalFactor representing the CPT of the child given the parents + */ + public IntervalFactor create(int childVar, int[] parentsVars, double[] childCuts, double[][] parentCuts, Op operation) { + + // root nodes creation + // add child node + int dimChild = childCuts.length + 1; + Strides stridesChild = Strides.var(childVar, dimChild); + + // create domain + Strides dom = Strides.empty(); + for (int i = 0; i < parentsVars.length; i++) { + dom = dom.and(parentsVars[i], parentCuts[i].length - 1); + } + // add parents nodes + IntervalFactorFactory factory = IntervalFactorFactory.factory().domain(stridesChild, dom); + + // create iterator + IndexIterator iterator = dom.getIterator(); + //iterate over all possible combinations + while (iterator.hasNext()) { + int[] comb = iterator.getPositions().clone(); + // be aware that the structure returned by the method has to be compliant with the child + double[] parentIntervalLower = new double[parentCuts.length]; + double[] parentIntervalUpper = new double[parentCuts.length]; + + for (int i = 0; i < parentCuts.length; i++) { + parentIntervalLower[i] = parentCuts[i][comb[i]]; + parentIntervalUpper[i] = parentCuts[i][comb[i] + 1]; + } + + double[] intervalBorders = borders(parentIntervalLower, parentIntervalUpper, operation); + //map the integers of the position of the childCuts + int[] intervalNumber = Arrays.stream(intervalBorders).mapToInt(val -> whichPosition(childCuts, val)).toArray(); + + // the lower is set to an array of zeroes the upper is set to 1 in the position of the interval number + factory.set(new double[dimChild], createUpper(intervalNumber, dimChild), comb); + iterator.next(); + } + return factory.get(); + } + + /** + * @param interval array containing the position + * @param dim dimension of the array to be generated + * @return double array with 1.0 set for every interval value + */ + public double[] createUpper(int[] interval, int dim) { + double[] upper = new double[dim]; + + // we set to 1.0 all the element relative to the interval + for (int ind : interval) { + upper[ind - 1] = 1.0; + } + return upper; + } + + /** + * Method to find the interval containing the specified value + * Specials: if the number is lower than the first cut, it will be placed in the first interval + * if the number is higher than the last cut, it will be placed in the last interval + * + * @param cutsX array of cuts, typically for the discretization + * @param X value that we want to place + * @return integer of the interval containing x + */ + //greedy + public int whichPosition(double[] cutsX, double X) { + int position = 0; //starts from 1 + for (int i = 1; i < cutsX.length - 1; i++) { + if (X <= cutsX[i]) { + break; + } + position++; + } + return position + 1; + } +} diff --git a/src/main/java/ch/idsia/crema/preprocess/creators/Op.java b/src/main/java/ch/idsia/crema/preprocess/creators/Op.java new file mode 100644 index 00000000..6c7bfef4 --- /dev/null +++ b/src/main/java/ch/idsia/crema/preprocess/creators/Op.java @@ -0,0 +1,6 @@ +package ch.idsia.crema.preprocess.creators; + +@FunctionalInterface +public interface Op { + double execute(double a, double b); +} diff --git a/src/test/java/ch/idsia/crema/preprocess/creators/CreateCPTTest.java b/src/test/java/ch/idsia/crema/preprocess/creators/CreateCPTTest.java new file mode 100644 index 00000000..abe46e2e --- /dev/null +++ b/src/test/java/ch/idsia/crema/preprocess/creators/CreateCPTTest.java @@ -0,0 +1,79 @@ +package ch.idsia.crema.preprocess.creators; + +import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class CreateCPTTest { + + private final CreateCPT creator = new CreateCPT(); + + @Test + public void testCreate() { + + //BMI example + double[] cutsH = new double[]{1.55, 1.60, 1.65, 1.70, 1.75, 1.80, 1.85, 1.90, 1.95, 2.00}; + double[] cutsW = new double[]{55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0}; + double[] cutsBMI = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0}; + + double[][] parents = new double[][]{cutsW, cutsH}; + Op bmi = (w, h) -> w / h / h; + + IntervalFactor cpt = creator.create(2, new int[]{0, 1}, cutsBMI, parents, bmi); + //System.out.println(cpt); + + // w1 and h1 + assertArrayEquals(cpt.getLower(0, 0), new double[10]); + assertArrayEquals(cpt.getUpper(0, 0), creator.createUpper(new int[]{4}, 10)); + + // w2 and h1 + assertArrayEquals(cpt.getLower(1, 0), new double[10]); + assertArrayEquals(cpt.getUpper(1, 0), creator.createUpper(new int[]{4, 5}, 10)); + + // w5 and h6 + assertArrayEquals(cpt.getLower(4, 5), new double[10]); + assertArrayEquals(cpt.getUpper(4, 5), creator.createUpper(new int[]{4}, 10)); + + // w12 and h9 + assertArrayEquals(cpt.getLower(11, 8), new double[10]); + assertArrayEquals(cpt.getUpper(11, 8), creator.createUpper(new int[]{5, 6}, 10)); + } + + @Test + public void testBorders() { + double[] parentIntervalLower = new double[]{55.0, 1.55}; + double[] parentIntervalUpper = new double[]{60.0, 1.60}; + //example of the BMI + Op bmi = (w, h) -> w / h / h; + + //bmi low = 21.48 + //bmi high = 24.97 + double tolerance = 0.001; + double[] interval = creator.borders(parentIntervalLower, parentIntervalUpper, bmi); + + assertArrayEquals(new double[]{21.484, 24.973}, interval, tolerance); + } + + @Test + public void testWhichPosition() { + double[] intervals = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0}; + assertEquals(1, creator.whichPosition(intervals, -10)); + assertEquals(1, creator.whichPosition(intervals, 11)); + assertEquals(1, creator.whichPosition(intervals, 15)); + assertEquals(3, creator.whichPosition(intervals, 17)); + assertEquals(5, creator.whichPosition(intervals, 29)); + assertEquals(8, creator.whichPosition(intervals, 41)); + assertEquals(8, creator.whichPosition(intervals, 200)); + } + + @Test + public void testCreateUpper() { + int[] interval = new int[]{3, 4}; + double[] result = creator.createUpper(interval, 4); + double[] expected = new double[]{.0, .0, 1.0, 1.0}; + + assertArrayEquals(result, expected); + } +}