Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added create CPT feature with interval helper methods #110

Merged
merged 1 commit into from Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package ch.idsia.crema.preprocess.creators;

import ch.idsia.crema.core.Strides;
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor;
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactorFactory;
import ch.idsia.crema.utility.IndexIterator;

import java.util.Arrays;

public class CreateCPT {

/**
* Border returns the borders of the intervals
*
* @param lower array of lower bounds for all the variables
* ex: double[] lower = new double[]{1.55, 55.0};
* @param upper array of upper bounds for all the variables
* ex: double[] upper = new double[]{1.90, 115.0};
*/
public double[] borders(double[] lower, double[] upper, Op operation) {

//shortcut in case the function is strictly monotone growing
//double[] extremes= new double[]{operation.execute(lower[0], upper[1]), operation.execute(upper[0], lower[1])};
//return Arrays.stream(extremes).sorted().toArray();

double[] results = new double[2 * lower.length];
results[0] = operation.execute(lower[0], lower[1]);
results[1] = operation.execute(lower[0], upper[1]);
results[2] = operation.execute(upper[0], lower[1]);
results[3] = operation.execute(upper[0], upper[1]);

Arrays.sort(results);
double firstElement = results[0];
double lastElement = results[results.length - 1];

return new double[]{firstElement, lastElement};
}

/**
* Method to create a CPT
*
* @param childVar variable of the child
* @param parentsVars array of variables of the parents
* @param childCuts array of cuts for the child
* @param parentCuts array of cuts for the parents
* @param operation operation to be performed with the cuts
* example K(bmi|w,H)
* @return IntervalFactor representing the CPT of the child given the parents
*/
public IntervalFactor create(int childVar, int[] parentsVars, double[] childCuts, double[][] parentCuts, Op operation) {

// root nodes creation
// add child node
int dimChild = childCuts.length + 1;
Strides stridesChild = Strides.var(childVar, dimChild);

// create domain
Strides dom = Strides.empty();
for (int i = 0; i < parentsVars.length; i++) {
dom = dom.and(parentsVars[i], parentCuts[i].length - 1);
}
// add parents nodes
IntervalFactorFactory factory = IntervalFactorFactory.factory().domain(stridesChild, dom);

// create iterator
IndexIterator iterator = dom.getIterator();
//iterate over all possible combinations
while (iterator.hasNext()) {
int[] comb = iterator.getPositions().clone();
// be aware that the structure returned by the method has to be compliant with the child
double[] parentIntervalLower = new double[parentCuts.length];
double[] parentIntervalUpper = new double[parentCuts.length];

for (int i = 0; i < parentCuts.length; i++) {
parentIntervalLower[i] = parentCuts[i][comb[i]];
parentIntervalUpper[i] = parentCuts[i][comb[i] + 1];
}

double[] intervalBorders = borders(parentIntervalLower, parentIntervalUpper, operation);
//map the integers of the position of the childCuts
int[] intervalNumber = Arrays.stream(intervalBorders).mapToInt(val -> whichPosition(childCuts, val)).toArray();

// the lower is set to an array of zeroes the upper is set to 1 in the position of the interval number
factory.set(new double[dimChild], createUpper(intervalNumber, dimChild), comb);
iterator.next();
}
return factory.get();
}

/**
* @param interval array containing the position
* @param dim dimension of the array to be generated
* @return double array with 1.0 set for every interval value
*/
public double[] createUpper(int[] interval, int dim) {
double[] upper = new double[dim];

// we set to 1.0 all the element relative to the interval
for (int ind : interval) {
upper[ind - 1] = 1.0;
}
return upper;
}

/**
* Method to find the interval containing the specified value
* Specials: if the number is lower than the first cut, it will be placed in the first interval
* if the number is higher than the last cut, it will be placed in the last interval
*
* @param cutsX array of cuts, typically for the discretization
* @param X value that we want to place
* @return integer of the interval containing x
*/
//greedy
public int whichPosition(double[] cutsX, double X) {
int position = 0; //starts from 1
for (int i = 1; i < cutsX.length - 1; i++) {
if (X <= cutsX[i]) {
break;
}
position++;
}
return position + 1;
}
}
6 changes: 6 additions & 0 deletions src/main/java/ch/idsia/crema/preprocess/creators/Op.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package ch.idsia.crema.preprocess.creators;

@FunctionalInterface
public interface Op {
double execute(double a, double b);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ch.idsia.crema.preprocess.creators;

import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class CreateCPTTest {

private final CreateCPT creator = new CreateCPT();

@Test
public void testCreate() {

//BMI example
double[] cutsH = new double[]{1.55, 1.60, 1.65, 1.70, 1.75, 1.80, 1.85, 1.90, 1.95, 2.00};
double[] cutsW = new double[]{55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0};
double[] cutsBMI = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0};

double[][] parents = new double[][]{cutsW, cutsH};
Op bmi = (w, h) -> w / h / h;

IntervalFactor cpt = creator.create(2, new int[]{0, 1}, cutsBMI, parents, bmi);
//System.out.println(cpt);

// w1 and h1
assertArrayEquals(cpt.getLower(0, 0), new double[10]);
assertArrayEquals(cpt.getUpper(0, 0), creator.createUpper(new int[]{4}, 10));

// w2 and h1
assertArrayEquals(cpt.getLower(1, 0), new double[10]);
assertArrayEquals(cpt.getUpper(1, 0), creator.createUpper(new int[]{4, 5}, 10));

// w5 and h6
assertArrayEquals(cpt.getLower(4, 5), new double[10]);
assertArrayEquals(cpt.getUpper(4, 5), creator.createUpper(new int[]{4}, 10));

// w12 and h9
assertArrayEquals(cpt.getLower(11, 8), new double[10]);
assertArrayEquals(cpt.getUpper(11, 8), creator.createUpper(new int[]{5, 6}, 10));
}

@Test
public void testBorders() {
double[] parentIntervalLower = new double[]{55.0, 1.55};
double[] parentIntervalUpper = new double[]{60.0, 1.60};
//example of the BMI
Op bmi = (w, h) -> w / h / h;

//bmi low = 21.48
//bmi high = 24.97
double tolerance = 0.001;
double[] interval = creator.borders(parentIntervalLower, parentIntervalUpper, bmi);

assertArrayEquals(new double[]{21.484, 24.973}, interval, tolerance);
}

@Test
public void testWhichPosition() {
double[] intervals = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0};
assertEquals(1, creator.whichPosition(intervals, -10));
assertEquals(1, creator.whichPosition(intervals, 11));
assertEquals(1, creator.whichPosition(intervals, 15));
assertEquals(3, creator.whichPosition(intervals, 17));
assertEquals(5, creator.whichPosition(intervals, 29));
assertEquals(8, creator.whichPosition(intervals, 41));
assertEquals(8, creator.whichPosition(intervals, 200));
}

@Test
public void testCreateUpper() {
int[] interval = new int[]{3, 4};
double[] result = creator.createUpper(interval, 4);
double[] expected = new double[]{.0, .0, 1.0, 1.0};

assertArrayEquals(result, expected);
}
}
Loading