Skip to content

Commit

Permalink
Implement and test predict_proba
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoDiFrancesco committed Apr 24, 2024
1 parent 6b38849 commit 107354a
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions src/classification/mondrian_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,33 +138,37 @@ impl<F: FType> Stats<F> {
.zip(self.counts.iter())
.enumerate()
{
// TODO (from nel215): case that var is 0 and count <= 1
// Shadow with bogous values count, sum, sq_sum, x
// let xx: Array1<F> = Array1::from_vec(vec![F::from_f32(1.5).unwrap(), F::from_f32(3.0).unwrap()]);
// let x = &xx;
// let count = 2;
// let sum = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]);
// let sq_sum: ArrayBase<ndarray::OwnedRepr<F>, Dim<[usize; 1]>> = Array1::from_vec(vec![F::from_f32(1.0).unwrap(), F::from_f32(2.0).unwrap()]);

let count_f = F::from_usize(count).unwrap();
let mean = &sum / count_f;
let var = (&sq_sum / count_f) - (&mean * &mean); // + F::epsilon()
let avg = &sum / count_f;
let var = (&sq_sum / count_f) - (&avg * &avg) + F::epsilon();
let sigma = (&var * count_f) / (count_f - F::one() + F::epsilon());
let pi = F::from_f32(std::f32::consts::PI).unwrap();
unimplemented!("Uncomment everything below and start from fixing z, z below is very different than nel215");
// let z = ((2.0 * pi * sigma).sqrt());
// let exp_term = x
// .iter()
// .zip(mean.iter())
// .map(|(&xi, &mi)| {
// let diff = xi - mi;
// (diff * diff) / (2.0 * sigma)
// })
// .sum::<f64>();

// let prob = (-exp_term).exp() / z;
// probs[index] = prob;
// sum_prob += prob;
let pi = F::from_f32(std::f32::consts::PI).unwrap() * F::from_f32(2.0).unwrap();
let z = pi.powi(x.len() as i32) * sigma.mapv(|s| s * s).sum().sqrt();
// Same as dot product
let dot_delta = (&(x - &avg) * &(x - &avg)).sum();
let dot_sigma = (&sigma * &sigma).sum();
let exponent = F::from_f32(0.5).unwrap() * dot_delta / dot_sigma;
let mut prob = exponent.exp() / z;
if count >= 1 {
assert!(!prob.is_nan(), "Probabaility should never be NaN.");
} else {
assert!(prob.is_nan(), "Probabaility should be NaN.");
prob = F::zero();
}
sum_prob += prob;
probs[index] = prob;
}

for prob in probs.iter_mut() {
*prob /= sum_prob;
}

unimplemented!("Finish uncommenting predict_proba()");
probs
}
}

0 comments on commit 107354a

Please sign in to comment.