Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for weighted fraction accumulator #385

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/boost/histogram/accumulators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <boost/histogram/accumulators/count.hpp>
#include <boost/histogram/accumulators/fraction.hpp>
#include <boost/histogram/accumulators/weighted_fraction.hpp>
#include <boost/histogram/accumulators/mean.hpp>
#include <boost/histogram/accumulators/sum.hpp>
#include <boost/histogram/accumulators/weighted_mean.hpp>
Expand Down
26 changes: 18 additions & 8 deletions include/boost/histogram/accumulators/fraction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <boost/core/nvp.hpp>
#include <boost/histogram/fwd.hpp> // for fraction<>
#include <boost/histogram/utility/wilson_interval.hpp>
#include <boost/histogram/weight.hpp>
#include <type_traits> // for std::common_type

namespace boost {
Expand All @@ -36,7 +37,8 @@ class fraction {
using const_reference = const value_type&;
using real_type = typename std::conditional<std::is_floating_point<value_type>::value,
value_type, double>::type;
using interval_type = typename utility::wilson_interval<real_type>::interval_type;
using score_type = typename utility::wilson_interval<real_type>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

score_type should not appear here, because it is not a type that needs to be publicly visible.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue should be addressed now. Please take a look.

using interval_type = typename score_type::interval_type;

fraction() noexcept = default;

Expand All @@ -51,11 +53,14 @@ class fraction {
static_cast<value_type>(e.failures())} {}

/// Insert boolean sample x.
void operator()(bool x) noexcept {
void operator()(bool x) noexcept { operator()(weight(1), x); }

/// Insert boolean sample x with weight w.
void operator()(const weight_type<value_type>& w, bool x) noexcept {
if (x)
++succ_;
succ_ += w.value;
else
++fail_;
fail_ += w.value;
}

/// Add another accumulator.
Expand All @@ -79,18 +84,23 @@ class fraction {

/// Return variance of the success fraction.
real_type variance() const noexcept {
return variance_for_p_and_n_eff(value(), count());
}

/// Calculate the variance for a given success fraction and effective number of samples.
template <class T>
static real_type variance_for_p_and_n_eff(const real_type& p, const T& n_eff) noexcept {
// We want to compute Var(p) for p = X / n with Var(X) = n p (1 - p)
// For Var(X) see
// https://en.wikipedia.org/wiki/Binomial_distribution#Expected_value_and_variance
// Error propagation: Var(p) = p'(X)^2 Var(X) = p (1 - p) / n
const real_type p = value();
return p * (1 - p) / count();
return p * (1 - p) / n_eff;
}

/// Return standard interval with 68.3 % confidence level (Wilson score interval).
interval_type confidence_interval() const noexcept {
return utility::wilson_interval<real_type>()(static_cast<real_type>(successes()),
static_cast<real_type>(failures()));
return score_type()(static_cast<real_type>(successes()),
static_cast<real_type>(failures()));
}

bool operator==(const fraction& rhs) const noexcept {
Expand Down
8 changes: 8 additions & 0 deletions include/boost/histogram/accumulators/ostream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>&
return detail::handle_nonzero_width(os, x);
}

template <class CharT, class Traits, class U>
std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os,
const weighted_fraction<U>& x) {
if (os.width() == 0)
return os << "weighted_fraction(" << x.get_fraction() << ", " << x.sum_w2() << ")";
return detail::handle_nonzero_width(os, x);
}

} // namespace accumulators
} // namespace histogram
} // namespace boost
Expand Down
193 changes: 193 additions & 0 deletions include/boost/histogram/accumulators/weighted_fraction.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Copyright 2022 Jay Gohil, Hans Dembinski
//
// Distributed under the Boost Software License, version 1.0.
// (See accompanying file LICENSE_1_0.txt
// or copy at http://www.boost.org/LICENSE_1_0.txt)

#ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP
#define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP

#include <boost/core/nvp.hpp>
#include <boost/histogram/accumulators/fraction.hpp>
#include <boost/histogram/detail/square.hpp>
#include <boost/histogram/fwd.hpp> // for weighted_fraction<>
#include <boost/histogram/weight.hpp>
#include <type_traits> // for std::common_type

namespace boost {
namespace histogram {
namespace accumulators {

namespace internal {

// Accumulates the sum of weights squared.
template <class ValueType>
class sum_of_weights_squared {
public:
using value_type = ValueType;
using const_reference = const value_type&;

sum_of_weights_squared() = default;

// Allow implicit conversion from sum_of_weights_squared<T>
template <class T>
sum_of_weights_squared(const sum_of_weights_squared<T>& o) noexcept
: sum_of_weights_squared(o.sum_of_weights_squared_) {}

// Initialize to external sum of weights squared.
sum_of_weights_squared(const_reference sum_w2) noexcept
: sum_of_weights_squared_(sum_w2) {}

// Increment by one.
sum_of_weights_squared& operator++() {
++sum_of_weights_squared_;
return *this;
}

// Increment by weight.
sum_of_weights_squared& operator+=(const weight_type<value_type>& w) {
sum_of_weights_squared_ += detail::square(w.value);
return *this;
}

// Added another sum_of_weights_squared.
sum_of_weights_squared& operator+=(const sum_of_weights_squared& rhs) {
sum_of_weights_squared_ += rhs.sum_of_weights_squared_;
return *this;
}

bool operator==(const sum_of_weights_squared& rhs) const noexcept {
return sum_of_weights_squared_ == rhs.sum_of_weights_squared_;
}

bool operator!=(const sum_of_weights_squared& rhs) const noexcept {
return !operator==(rhs);
}

// Return sum of weights squared.
const_reference value() const noexcept { return sum_of_weights_squared_; }

template <class Archive>
void serialize(Archive& ar, unsigned /* version */) {
ar& make_nvp("sum_of_weights_squared", sum_of_weights_squared_);
}

private:
ValueType sum_of_weights_squared_{};
};

} // namespace internal

/// Accumulates weighted boolean samples and computes the fraction of true samples.
template <class ValueType>
class weighted_fraction {
public:
using value_type = ValueType;
using const_reference = const value_type&;
using fraction_type = fraction<ValueType>;
using real_type = typename fraction_type::real_type;
using score_type = typename fraction_type::score_type;
using interval_type = typename score_type::interval_type;

weighted_fraction() noexcept = default;

/// Initialize to external fraction and sum of weights squared.
weighted_fraction(const fraction_type& f, const_reference sum_w2) noexcept
: f_(f), sum_w2_(sum_w2) {}

/// Convert the weighted_fraction class to a different type T.
template <class T>
operator weighted_fraction<T>() const noexcept {
return weighted_fraction<T>(static_cast<fraction<T>>(f_),
static_cast<T>(sum_w2_.value()));
}

/// Insert boolean sample x with weight 1.
void operator()(bool x) noexcept { operator()(weight(1), x); }

/// Insert boolean sample x with weight w.
void operator()(const weight_type<value_type>& w, bool x) noexcept {
f_(w, x);
sum_w2_ += w;
}

/// Add another weighted_fraction.
weighted_fraction& operator+=(const weighted_fraction& rhs) noexcept {
f_ += rhs.f_;
sum_w2_ += rhs.sum_w2_;
return *this;
}

bool operator==(const weighted_fraction& rhs) const noexcept {
return f_ == rhs.f_ && sum_w2_ == rhs.sum_w2_;
}

bool operator!=(const weighted_fraction& rhs) const noexcept {
return !operator==(rhs);
}

/// Return number of boolean samples that were true.
const_reference successes() const noexcept { return f_.successes(); }

/// Return number of boolean samples that were false.
const_reference failures() const noexcept { return f_.failures(); }

/// Return effective number of boolean samples.
real_type count() const noexcept {
return static_cast<real_type>(detail::square(f_.count())) / sum_w2_.value();
}

/// Return success weighted_fraction of boolean samples.
real_type value() const noexcept { return f_.value(); }

/// Return variance of the success weighted_fraction.
real_type variance() const noexcept {
return fraction_type::variance_for_p_and_n_eff(value(), count());
}

/// Return the sum of weights squared.
value_type sum_of_weights_squared() const noexcept { return sum_w2_.value(); }

/// Return standard interval with 68.3 % confidence level (Wilson score interval).
interval_type confidence_interval() const noexcept {
const real_type n_eff = count();
const real_type p_hat = value();
const real_type correction = score_type::third_order_correction(n_eff);
return score_type().wilson_solve_for_neff_phat_correction(n_eff, p_hat, correction);
}

/// Return the fraction.
const fraction_type& get_fraction() const noexcept { return f_; }

/// Return the sum of weights squared.
const value_type& sum_w2() const noexcept { return sum_w2_.value(); }

template <class Archive>
void serialize(Archive& ar, unsigned /* version */) {
ar& make_nvp("fraction", f_);
ar& make_nvp("sum_of_weights_squared", sum_w2_);
}

private:
fraction_type f_;
internal::sum_of_weights_squared<ValueType> sum_w2_;
};

} // namespace accumulators
} // namespace histogram
} // namespace boost

#ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED

namespace std {
template <class T, class U>
/// Specialization for boost::histogram::accumulators::weighted_fraction.
struct common_type<boost::histogram::accumulators::weighted_fraction<T>,
boost::histogram::accumulators::weighted_fraction<U>> {
using type = boost::histogram::accumulators::weighted_fraction<common_type_t<T, U>>;
};
} // namespace std

#endif

#endif
10 changes: 10 additions & 0 deletions include/boost/histogram/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ class count;
template <class ValueType = double>
class fraction;

namespace internal {

template <class ValueType = double>
class sum_of_weights_squared;

} // namespace internal

template <class ValueType = double>
class weighted_fraction;

template <class ValueType = double>
class sum;

Expand Down
65 changes: 65 additions & 0 deletions include/boost/histogram/utility/wilson_interval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,72 @@ class wilson_interval : public binomial_proportion_interval<ValueType> {
return {t1 - t2, t1 + t2};
}

/// Returns the third order correction for n_eff.
static value_type third_order_correction(value_type n_eff) noexcept {
// The approximate formula reads:
// f(n) = (n³ + n² + 2n + 6) / n³
//
// Applying the substitution x = 1 / n gives:
// f(n) = 1 + x + 2x² + 6x³
//
// Using Horner's method to evaluate this polynomial gives:
// f(n) = 1 + x (1 + x (2 + 6x))
if (n_eff == 0) return 1;
const value_type x = 1 / n_eff;
return 1 + x * (1 + x * (2 + 6 * x));
}

/** Computer the confidence interval for the provided problem.

@param p The problem to solve.
*/
interval_type wilson_solve_for_neff_phat_correction(
const value_type& n_eff, const value_type& p_hat,
const value_type& correction) const noexcept {
// Equation 41 from this paper: https://arxiv.org/abs/2110.00294
// (p̂ - p)² = p (1 - p) (z² f(n) / n)
// Multiply by n to avoid floating point error when n = 0.
// n (p̂ - p)² = p (1 - p) z² f(n)
// Expand.
// np² - 2np̂p + np̂² = pz²f(n) - p²z²f(n)
// Collect terms of p.
// p²(n + z²f(n)) + p(-2np̂ - z²f(n)) + (np̂²) = 0
//
// This is a quadratic equation ap² + bp + c = 0 where
// a = n + z²f(n)
// b = -2np̂ - z²f(n)
// c = np̂²

const value_type zz_correction = (z_ * z_) * correction;

const value_type a = n_eff + zz_correction;
const value_type b = -2 * n_eff * p_hat - zz_correction;
const value_type c = n_eff * (p_hat * p_hat);

return quadratic_roots(a, b, c);
}

private:
// Finds the roots of the quadratic equation ax² + bx + c = 0.
static interval_type quadratic_roots(const value_type& a, const value_type& b,
const value_type& c) noexcept {
// https://people.csail.mit.edu/bkph/articles/Quadratics.pdf

const value_type two_a = 2 * a;
const value_type two_c = 2 * c;
const value_type sqrt_bb_4ac = std::sqrt(b * b - two_a * two_c);

if (b >= 0) {
const value_type root1 = (-b - sqrt_bb_4ac) / two_a;
const value_type root2 = two_c / (-b - sqrt_bb_4ac);
return {root1, root2};
} else {
const value_type root1 = two_c / (-b + sqrt_bb_4ac);
const value_type root2 = (-b + sqrt_bb_4ac) / two_a;
return {root1, root2};
}
}

value_type z_;
};

Expand Down
3 changes: 2 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ boost_test(TYPE compile-fail SOURCES histogram_fail4.cpp)
set(BOOST_TEST_LINK_LIBRARIES Boost::histogram Boost::core)

boost_test(TYPE run SOURCES accumulators_count_test.cpp)
boost_test(TYPE run SOURCES accumulators_fraction_test.cpp)
boost_test(TYPE run SOURCES accumulators_weighted_fraction_test.cpp)
boost_test(TYPE run SOURCES accumulators_mean_test.cpp)
boost_test(TYPE run SOURCES accumulators_sum_test.cpp)
boost_test(TYPE run SOURCES accumulators_weighted_mean_test.cpp)
Expand Down Expand Up @@ -96,7 +98,6 @@ boost_test(TYPE run SOURCES unlimited_storage_test.cpp)
boost_test(TYPE run SOURCES tools_test.cpp)
boost_test(TYPE run SOURCES issue_327_test.cpp)
boost_test(TYPE run SOURCES issue_353_test.cpp)
boost_test(TYPE run SOURCES accumulators_fraction_test.cpp)
boost_test(TYPE run SOURCES utility_binomial_proportion_interval_test.cpp)
boost_test(TYPE run SOURCES utility_wald_interval_test.cpp)
boost_test(TYPE run SOURCES utility_wilson_interval_test.cpp)
Expand Down
1 change: 1 addition & 0 deletions test/Jamfile
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ alias odr :
alias cxx14 :
[ run accumulators_count_test.cpp ]
[ run accumulators_fraction_test.cpp ]
[ run accumulators_weighted_fraction_test.cpp ]
[ run accumulators_mean_test.cpp ]
[ run accumulators_sum_test.cpp : : :
# make sure sum accumulator works even with -ffast-math and optimizations
Expand Down
Loading