Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for weighted fraction accumulator #385

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/boost/histogram/accumulators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
#include <boost/histogram/accumulators/sum.hpp>
#include <boost/histogram/accumulators/weighted_mean.hpp>
#include <boost/histogram/accumulators/weighted_sum.hpp>
#include <boost/histogram/experimental/weighted_fraction.hpp>

#endif
42 changes: 30 additions & 12 deletions include/boost/histogram/accumulators/fraction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include <boost/core/nvp.hpp>
#include <boost/histogram/fwd.hpp> // for fraction<>
#include <boost/histogram/utility/binomial_proportion_interval.hpp>
#include <boost/histogram/utility/wilson_interval.hpp>
#include <boost/histogram/weight.hpp>
#include <type_traits> // for std::common_type

namespace boost {
Expand All @@ -36,7 +38,8 @@ class fraction {
using const_reference = const value_type&;
using real_type = typename std::conditional<std::is_floating_point<value_type>::value,
value_type, double>::type;
using interval_type = typename utility::wilson_interval<real_type>::interval_type;
using interval_type =
typename utility::binomial_proportion_interval<real_type>::interval_type;

fraction() noexcept = default;

Expand All @@ -51,11 +54,14 @@ class fraction {
static_cast<value_type>(e.failures())} {}

/// Insert boolean sample x.
void operator()(bool x) noexcept {
void operator()(bool x) noexcept { operator()(weight(1), x); }

/// Insert boolean sample x with weight w.
void operator()(const weight_type<value_type>& w, bool x) noexcept {
if (x)
++succ_;
succ_ += w.value;
else
++fail_;
fail_ += w.value;
}

/// Add another accumulator.
Expand All @@ -79,18 +85,12 @@ class fraction {

/// Return variance of the success fraction.
real_type variance() const noexcept {
// We want to compute Var(p) for p = X / n with Var(X) = n p (1 - p)
// For Var(X) see
// https://en.wikipedia.org/wiki/Binomial_distribution#Expected_value_and_variance
// Error propagation: Var(p) = p'(X)^2 Var(X) = p (1 - p) / n
const real_type p = value();
return p * (1 - p) / count();
return variance_for_p_and_n_eff(value(), count());
}

/// Return standard interval with 68.3 % confidence level (Wilson score interval).
interval_type confidence_interval() const noexcept {
return utility::wilson_interval<real_type>()(static_cast<real_type>(successes()),
static_cast<real_type>(failures()));
return confidence_interval(utility::wilson_interval<real_type>());
}

bool operator==(const fraction& rhs) const noexcept {
Expand All @@ -106,6 +106,24 @@ class fraction {
}

private:
friend class weighted_fraction<value_type>;

// Variance of the success fraction p for an effective sample count n_eff.
//
// Derivation: for X successes out of n samples, Var(X) = n p (1 - p), see
// https://en.wikipedia.org/wiki/Binomial_distribution#Expected_value_and_variance
// Error propagation through p = X / n then gives
// Var(p) = p'(X)^2 Var(X) = p (1 - p) / n.
//
// T is the type of the sample count; it is left generic so that the count can be
// passed without an explicit conversion. No guard against n_eff == 0 is applied.
template <class T>
static real_type variance_for_p_and_n_eff(const real_type& p, const T& n_eff) noexcept {
  const real_type one_minus_p = 1 - p;
  return p * one_minus_p / n_eff;
}

// Evaluate the supplied binomial proportion interval computer on the accumulated
// successes and failures, both converted to real_type first.
interval_type confidence_interval(
    const utility::binomial_proportion_interval<real_type>& b) const noexcept {
  const real_type n_success = static_cast<real_type>(successes());
  const real_type n_failure = static_cast<real_type>(failures());
  return b(n_success, n_failure);
}

value_type succ_{};
value_type fail_{};
};
Expand Down
8 changes: 8 additions & 0 deletions include/boost/histogram/accumulators/ostream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>&
return detail::handle_nonzero_width(os, x);
}

// Stream a weighted_fraction as "weighted_fraction(<fraction>, <sum of weights squared>)".
// When a field width is set on the stream, formatting is delegated to
// detail::handle_nonzero_width.
template <class CharT, class Traits, class U>
std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os,
                                              const weighted_fraction<U>& x) {
  if (os.width() != 0) return detail::handle_nonzero_width(os, x);
  os << "weighted_fraction(" << x.get_fraction() << ", " << x.sum_w2() << ")";
  return os;
}

} // namespace accumulators
} // namespace histogram
} // namespace boost
Expand Down
198 changes: 198 additions & 0 deletions include/boost/histogram/experimental/weighted_fraction.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright 2022 Jay Gohil, Hans Dembinski
//
// Distributed under the Boost Software License, version 1.0.
// (See accompanying file LICENSE_1_0.txt
// or copy at http://www.boost.org/LICENSE_1_0.txt)

#ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP
#define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP

#include <boost/core/nvp.hpp>
#include <boost/histogram/accumulators/fraction.hpp>
#include <boost/histogram/detail/square.hpp>
#include <boost/histogram/fwd.hpp> // for weighted_fraction<>
#include <boost/histogram/weight.hpp>
#include <type_traits> // for std::common_type

namespace boost {
namespace histogram {
namespace accumulators {

namespace internal {

// Accumulates the sum of weights squared.
//
// The accumulator holds a single value of type ValueType. It can be incremented by
// one, by the square of a weight, or by the content of another accumulator of the
// same type.
template <class ValueType>
class sum_of_weights_squared {
public:
  using value_type = ValueType;
  using const_reference = const value_type&;

  sum_of_weights_squared() = default;

  // Allow implicit conversion from sum_of_weights_squared<T>.
  //
  // The source is read through its public accessor: sum_of_weights_squared<T> is a
  // distinct class whenever T differs from ValueType, so its private member is not
  // accessible from here.
  template <class T>
  sum_of_weights_squared(const sum_of_weights_squared<T>& o) noexcept
      : sum_of_weights_squared_(static_cast<value_type>(o.value())) {}

  // Initialize to external sum of weights squared.
  sum_of_weights_squared(const_reference sum_w2) noexcept
      : sum_of_weights_squared_(sum_w2) {}

  // Increment by one.
  sum_of_weights_squared& operator++() {
    ++sum_of_weights_squared_;
    return *this;
  }

  // Increment by the square of weight w.
  sum_of_weights_squared& operator+=(const weight_type<value_type>& w) {
    sum_of_weights_squared_ += detail::square(w.value);
    return *this;
  }

  // Add another sum_of_weights_squared.
  sum_of_weights_squared& operator+=(const sum_of_weights_squared& rhs) {
    sum_of_weights_squared_ += rhs.sum_of_weights_squared_;
    return *this;
  }

  // Two accumulators compare equal when their stored sums compare equal.
  bool operator==(const sum_of_weights_squared& rhs) const noexcept {
    return sum_of_weights_squared_ == rhs.sum_of_weights_squared_;
  }

  bool operator!=(const sum_of_weights_squared& rhs) const noexcept {
    return !operator==(rhs);
  }

  // Return sum of weights squared.
  const_reference value() const noexcept { return sum_of_weights_squared_; }

  // Serialize the stored sum under the name "sum_of_weights_squared".
  template <class Archive>
  void serialize(Archive& ar, unsigned /* version */) {
    ar& make_nvp("sum_of_weights_squared", sum_of_weights_squared_);
  }

private:
  ValueType sum_of_weights_squared_{};
};

} // namespace internal

/// Accumulates weighted boolean samples and computes the fraction of true samples.
template <class ValueType>
class weighted_fraction {
public:
using value_type = ValueType;
using const_reference = const value_type&;
using fraction_type = fraction<ValueType>;
using real_type = typename fraction_type::real_type;
using interval_type = typename fraction_type::interval_type;

weighted_fraction() noexcept = default;

/// Initialize to external fraction and sum of weights squared.
weighted_fraction(const fraction_type& f, const_reference sum_w2) noexcept
: f_(f), sum_w2_(sum_w2) {}

/// Convert the weighted_fraction class to a different type T.
template <class T>
operator weighted_fraction<T>() const noexcept {
return weighted_fraction<T>(static_cast<fraction<T>>(f_),
static_cast<T>(sum_w2_.value()));
}

/// Insert boolean sample x with weight 1.
void operator()(bool x) noexcept { operator()(weight(1), x); }

/// Insert boolean sample x with weight w.
void operator()(const weight_type<value_type>& w, bool x) noexcept {
f_(w, x);
sum_w2_ += w;
}

/// Add another weighted_fraction.
weighted_fraction& operator+=(const weighted_fraction& rhs) noexcept {
f_ += rhs.f_;
sum_w2_ += rhs.sum_w2_;
return *this;
}

bool operator==(const weighted_fraction& rhs) const noexcept {
return f_ == rhs.f_ && sum_w2_ == rhs.sum_w2_;
}

bool operator!=(const weighted_fraction& rhs) const noexcept {
return !operator==(rhs);
}

/// Return number of boolean samples that were true.
const_reference successes() const noexcept { return f_.successes(); }

/// Return number of boolean samples that were false.
const_reference failures() const noexcept { return f_.failures(); }

/// Return effective number of boolean samples.
real_type count() const noexcept {
return static_cast<real_type>(detail::square(f_.count())) / sum_w2_.value();
}

/// Return success weighted_fraction of boolean samples.
real_type value() const noexcept { return f_.value(); }

/// Return variance of the success weighted_fraction.
real_type variance() const noexcept {
return fraction_type::variance_for_p_and_n_eff(value(), count());
}

/// Return the sum of weights squared.
value_type sum_of_weights_squared() const noexcept { return sum_w2_.value(); }

/// Return standard interval with 68.3 % confidence level (Wilson score interval).
interval_type confidence_interval() const noexcept {
return confidence_interval(utility::wilson_interval<real_type>());
}

/// Return the Wilson score interval.
interval_type confidence_interval(
const utility::wilson_interval<real_type>& w) const noexcept {
const real_type n_eff = count();
const real_type p_hat = value();
const real_type correction = w.third_order_correction(n_eff);
return w.solve_for_neff_phat_correction(n_eff, p_hat, correction);
}

/// Return the fraction.
const fraction_type& get_fraction() const noexcept { return f_; }

/// Return the sum of weights squared.
const value_type& sum_w2() const noexcept { return sum_w2_.value(); }

template <class Archive>
void serialize(Archive& ar, unsigned /* version */) {
ar& make_nvp("fraction", f_);
ar& make_nvp("sum_of_weights_squared", sum_w2_);
}

private:
fraction_type f_;
internal::sum_of_weights_squared<ValueType> sum_w2_;
};

} // namespace accumulators
} // namespace histogram
} // namespace boost

#ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED

namespace std {
/// Specialization for boost::histogram::accumulators::weighted_fraction.
///
/// The common type of two weighted_fraction instantiations is the weighted_fraction
/// over the common type of their value types.
template <class T, class U>
struct common_type<boost::histogram::accumulators::weighted_fraction<T>,
                   boost::histogram::accumulators::weighted_fraction<U>> {
  using type =
      boost::histogram::accumulators::weighted_fraction<typename common_type<T, U>::type>;
};
} // namespace std

#endif

#endif
10 changes: 10 additions & 0 deletions include/boost/histogram/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ class count;
template <class ValueType = double>
class fraction;

namespace internal {

template <class ValueType = double>
class sum_of_weights_squared;

} // namespace internal

template <class ValueType = double>
class weighted_fraction;

template <class ValueType = double>
class sum;

Expand Down
65 changes: 65 additions & 0 deletions include/boost/histogram/utility/wilson_interval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,72 @@ class wilson_interval : public binomial_proportion_interval<ValueType> {
return {t1 - t2, t1 + t2};
}

/// Returns the third order correction for n_eff.
///
/// Returns 1 when n_eff compares equal to zero, so that no division by zero occurs.
static value_type third_order_correction(value_type n_eff) noexcept {
  if (n_eff == 0) return 1;

  // The approximate formula reads
  //   f(n) = (n³ + n² + 2n + 6) / n³
  // which, with x = 1 / n, is the polynomial
  //   f(n) = 1 + x + 2x² + 6x³ = 1 + x (1 + x (2 + 6x)).
  // The nested form (Horner's method) is evaluated from the inside out.
  const value_type inv_n = 1 / n_eff;
  const value_type inner = 2 + 6 * inv_n;
  const value_type middle = 1 + inv_n * inner;
  return 1 + inv_n * middle;
}

/** Compute the confidence interval from an effective sample count, an estimated
  fraction and a correction factor.

  @param n_eff Effective number of samples n.
  @param p_hat Estimated success fraction p̂.
  @param correction Correction factor f(n) that multiplies z².
  @return The two roots of the quadratic equation derived below.
*/
interval_type solve_for_neff_phat_correction(
    const value_type& n_eff, const value_type& p_hat,
    const value_type& correction) const noexcept {
  // Starting point is equation 41 of https://arxiv.org/abs/2110.00294
  //   (p̂ - p)² = p (1 - p) (z² f(n) / n)
  // Multiplying by n avoids floating point error when n = 0:
  //   n (p̂ - p)² = p (1 - p) z² f(n)
  // Expanding both sides:
  //   np² - 2np̂p + np̂² = pz²f(n) - p²z²f(n)
  // Collecting powers of p yields the quadratic ap² + bp + c = 0 with
  //   a = n + z²f(n)
  //   b = -2np̂ - z²f(n)
  //   c = np̂²

  // z²f(n) appears in both a and b, so it is computed once.
  const value_type z2f = (z_ * z_) * correction;

  const value_type quad_coeff = n_eff + z2f;
  const value_type lin_coeff = -2 * n_eff * p_hat - z2f;
  const value_type const_coeff = n_eff * (p_hat * p_hat);

  return quadratic_roots(quad_coeff, lin_coeff, const_coeff);
}

private:
// Finds the roots of the quadratic equation ax² + bx + c = 0.
//
// Follows https://people.csail.mit.edu/bkph/articles/Quadratics.pdf : the sign in
// front of the square root is chosen to match the sign of -b, and the second root
// is obtained as 2c divided by that same term instead of using the opposite sign.
// No check is made for a negative discriminant or for a zero divisor.
static interval_type quadratic_roots(const value_type& a, const value_type& b,
                                     const value_type& c) noexcept {
  const value_type two_a = 2 * a;
  const value_type two_c = 2 * c;
  const value_type disc_root = std::sqrt(b * b - two_a * two_c);

  const bool b_non_negative = b >= 0;
  // Term shared by both roots: -b combined with the root of the discriminant.
  const value_type q = b_non_negative ? (-b - disc_root) : (-b + disc_root);

  if (b_non_negative) {
    const value_type first = q / two_a;
    const value_type second = two_c / q;
    return {first, second};
  }
  const value_type first = two_c / q;
  const value_type second = q / two_a;
  return {first, second};
}

value_type z_;
};

Expand Down
Loading