Skip to content

Commit

Permalink
Fix and simplify Pearson coefficient
Browse files Browse the repository at this point in the history
  • Loading branch information
vks committed Dec 17, 2023
1 parent 339fa05 commit 5ba4923
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
36 changes: 12 additions & 24 deletions src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ impl Covariance {
let n = self.n.to_f64().unwrap();

let delta_x = x - self.avg_x;
let delta_y = y - self.avg_y;
let delta_x_n = delta_x / n;
let delta_y_n = (y - self.avg_y) / n;

self.avg_x += delta_x / n;
self.sum_x_2 += delta_x * delta_x * n * (n - 1.);
self.avg_x += delta_x_n;
self.sum_x_2 += delta_x_n * delta_x_n * n * (n - 1.);

self.avg_y += delta_y / n;
self.sum_y_2 += delta_y * delta_y * n * (n - 1.);
self.avg_y += delta_y_n;
self.sum_y_2 += delta_y_n * delta_y_n * n * (n - 1.);

self.sum_prod += delta_x * (y - self.avg_y);
}
Expand Down Expand Up @@ -91,30 +92,17 @@ impl Covariance {
self.sum_prod / (self.n - 1).to_f64().unwrap()
}

/// Calculate the population Pearson correlation coefficient of the sample.
/// Calculate the population Pearson correlation coefficient.
///
/// Returns NaN if the sample contains fewer than two observations.
#[cfg(any(feature = "std", feature = "libm"))]
#[cfg_attr(doc_cfg, doc(cfg(any(feature = "std", feature = "libm"))))]
#[inline]
pub fn population_pearson(&self) -> f64 {
let cov = self.population_covariance();
let var_x = self.population_variance_x();
let var_y = self.population_variance_y();
cov / num_traits::Float::sqrt(var_x * var_y)
}

/// Calculate the sample Pearson correlation coefficient.
///
/// Returns NaN for an empty sample.
#[cfg(any(feature = "std", feature = "libm"))]
#[cfg_attr(doc_cfg, doc(cfg(any(feature = "std", feature = "libm"))))]
#[inline]
pub fn sample_pearson(&self) -> f64 {
let cov = self.sample_covariance();
let var_x = self.sample_variance_x();
let var_y = self.sample_variance_y();
cov / num_traits::Float::sqrt(var_x * var_y)
pub fn pearson(&self) -> f64 {
// With fewer than two observations the coefficient is undefined, so
// return NaN explicitly rather than relying on a 0/0 division.
if self.n < 2 {
return f64::NAN;
}
// Works directly on the accumulated raw sums: any common 1/n or 1/(n-1)
// normalisation of the covariance and the two variances cancels in the
// ratio, so the population and sample coefficients are the same value.
// NOTE(review): the result is still NaN when either sum of squares is
// zero (constant x or y), because the denominator is then zero.
self.sum_prod / num_traits::Float::sqrt(self.sum_x_2 * self.sum_y_2)
}

/// Return the sample size.
Expand Down
33 changes: 30 additions & 3 deletions tests/integration/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,66 @@ fn simple() {
let mut cov = Covariance::new();
assert!(cov.mean_x().is_nan());
assert!(cov.mean_y().is_nan());
assert!(cov.population_variance_x().is_nan());
assert!(cov.population_variance_y().is_nan());
assert!(cov.sample_variance_x().is_nan());
assert!(cov.sample_variance_y().is_nan());
assert!(cov.population_covariance().is_nan());
assert!(cov.sample_covariance().is_nan());
assert!(cov.population_pearson().is_nan());
assert!(cov.sample_pearson().is_nan());
assert!(cov.pearson().is_nan());

cov.add(1., 5.);
assert_eq!(cov.mean_x(), 1.);
assert_eq!(cov.mean_y(), 5.);
assert_eq!(cov.population_variance_x(), 0.);
assert_eq!(cov.population_variance_y(), 0.);
assert!(cov.sample_variance_x().is_nan());
assert!(cov.sample_variance_y().is_nan());
assert_eq!(cov.population_covariance(), 0.);
assert!(cov.sample_covariance().is_nan());
// TODO: pearson
assert!(cov.pearson().is_nan());

cov.add(2., 4.);
assert_eq!(cov.mean_x(), 1.5);
assert_eq!(cov.mean_y(), 4.5);
assert_eq!(cov.population_variance_x(), 0.25);
assert_eq!(cov.population_variance_y(), 0.25);
assert_eq!(cov.sample_variance_x(), 0.5);
assert_eq!(cov.sample_variance_y(), 0.5);
assert_eq!(cov.population_covariance(), -0.25);
assert_eq!(cov.sample_covariance(), -0.5);
assert_eq!(cov.pearson(), -1.);

cov.add(3., 3.);
assert_eq!(cov.mean_x(), 2.);
assert_eq!(cov.mean_y(), 4.);
assert_eq!(cov.population_variance_x(), 2./3.);
assert_eq!(cov.population_variance_y(), 2./3.);
assert_eq!(cov.sample_variance_x(), 1.);
assert_eq!(cov.sample_variance_y(), 1.);
assert_eq!(cov.population_covariance(), -2./3.);
assert_eq!(cov.sample_covariance(), -1.);
assert_eq!(cov.pearson(), -1.);

cov.add(4., 2.);
assert_eq!(cov.mean_x(), 2.5);
assert_eq!(cov.mean_y(), 3.5);
assert_eq!(cov.population_variance_x(), 1.25);
assert_eq!(cov.population_variance_y(), 1.25);
assert_eq!(cov.sample_variance_x(), 5./3.);
assert_eq!(cov.sample_variance_y(), 5./3.);
assert_eq!(cov.population_covariance(), -1.25);
assert_eq!(cov.sample_covariance(), -5./3.);
assert_eq!(cov.pearson(), -1.);

cov.add(5., 1.);
assert_eq!(cov.mean_x(), 3.);
assert_eq!(cov.mean_y(), 3.);
assert_eq!(cov.population_variance_x(), 2.);
assert_eq!(cov.population_variance_y(), 2.);
assert_eq!(cov.sample_variance_x(), 2.5);
assert_eq!(cov.sample_variance_y(), 2.5);
assert_eq!(cov.population_covariance(), -2.0);
assert_eq!(cov.sample_covariance(), -2.5);
assert_eq!(cov.pearson(), -1.);
}

0 comments on commit 5ba4923

Please sign in to comment.