Skip to content

Commit

Permalink
Improved efficency of median calculation in the even case.
Browse files Browse the repository at this point in the history
This relies on the fact that, after nth_element() is called, all
elements that are not greater than the n-th element have been moved
before the n-th element. Thus, to find the (n-1)-th element, we simply
have to take the maximum of the elements before the n-th element with
max_element(), rather than calling nth_element() again. The former is
guaranteed linear and much simpler than the latter.
  • Loading branch information
LTLA committed Aug 20, 2024
1 parent 036ea1b commit 87ed962
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 72 deletions.
77 changes: 50 additions & 27 deletions include/tatami_stats/medians.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,19 @@ Output_ direct(Value_* ptr, Index_ num, bool skip_nan) {
size_t halfway = num / 2;
bool is_even = (num % 2 == 0);

// At some point, I found two nth_element calls to be faster than partial_sort.
std::nth_element(ptr, ptr + halfway, ptr + num);
double medtmp = *(ptr + halfway);
if (is_even) {
std::nth_element(ptr, ptr + halfway - 1, ptr + num);
return (medtmp + *(ptr + halfway - 1))/2;
} else {
Output_ medtmp = *(ptr + halfway);
if (!is_even) {
return medtmp;
}

// 'nth_element()' reorganizes 'ptr' so that everything below 'halfway' is
// less than or equal to 'ptr[halfway]', while everything above 'halfway'
// is greater than or equal to 'ptr[halfway]'. Thus, to get the element
// immediately before 'halfway' in the sort order, we just need to find the
// maximum from '[0, halfway)'.
Output_ other = *std::max_element(ptr, ptr + halfway);
return (medtmp + other)/2;
}

/**
Expand All @@ -126,6 +130,9 @@ Output_ direct(Value_* ptr, Index_ num, bool skip_nan) {
*/
template<typename Output_ = double, typename Value_, typename Index_>
Output_ direct(Value_* value, Index_ num_nonzero, Index_ num_all, bool skip_nan) {
// Fallback to the dense code if there are no structural zeros. This is not
// just for efficiency as the downstream averaging code assumes that there
// is at least one structural zero when considering its scenarios.
if (num_nonzero == num_all) {
return direct<Output_>(value, num_all, skip_nan);
}
Expand All @@ -151,35 +158,51 @@ Output_ direct(Value_* value, Index_ num_nonzero, Index_ num_all, bool skip_nan)
size_t halfway = num_all / 2;
bool is_even = (num_all % 2 == 0);

auto vend = value + num_nonzero;
std::sort(value, vend);
size_t zeropos = std::lower_bound(value, vend, 0) - value;
size_t nzero = num_all - num_nonzero;
size_t num_zero = num_all - num_nonzero;
size_t num_negative = 0;
for (Index_ i = 0; i < num_nonzero; ++i) {
num_negative += (value[i] < 0);
}

if (!is_even) {
if (zeropos > halfway) {
if (num_negative > halfway) {
std::nth_element(value, value + halfway, value + num_nonzero);
return value[halfway];
} else if (halfway >= zeropos + nzero) {
return value[halfway - nzero];

} else if (halfway >= num_negative + num_zero) {
size_t skip_zeros = halfway - num_zero;
std::nth_element(value, value + skip_zeros, value + num_nonzero);
return value[skip_zeros];

} else {
return 0; // zero is the median.
return 0;
}
}

double tmp = 0;
if (zeropos > halfway) {
tmp = value[halfway] + value[halfway - 1];
} else if (zeropos == halfway) {
// guaranteed to be at least 1 zero.
tmp += value[halfway - 1];
} else if (zeropos < halfway && zeropos + nzero > halfway) {
; // zero is the median.
} else if (zeropos + nzero == halfway) {
// guaranteed to be at least 1 zero.
tmp += value[halfway - nzero];
} else {
tmp = value[halfway - nzero] + value[halfway - nzero - 1];
Output_ tmp = 0;
if (num_negative > halfway) { // both halves of the median are negative.
std::nth_element(value, value + halfway, value + num_nonzero);
tmp = value[halfway] + *(std::max_element(value, value + halfway)); // max_element gets the sorted value at halfway - 1, see explanation for the dense case.

} else if (num_negative == halfway) { // the upper half is guaranteed to be zero.
size_t below_halfway = halfway - 1;
std::nth_element(value, value + below_halfway, value + num_nonzero);
tmp = value[below_halfway];

} else if (num_negative < halfway && num_negative + num_zero > halfway) { // both halves are zero, so zero is the median.
;

} else if (num_negative + num_zero == halfway) { // the lower half is guaranteed to be zero.
size_t skip_zeros = halfway - num_zero;
std::nth_element(value, value + skip_zeros, value + num_nonzero);
tmp = value[skip_zeros];

} else { // both halves of the median are non-negative.
size_t skip_zeros = halfway - num_zero;
std::nth_element(value, value + skip_zeros, value + num_nonzero);
tmp = value[skip_zeros] + *(std::max_element(value, value + skip_zeros)); // max_element gets the sorted value at skip_zeros - 1, see explanation for the dense case.
}

return tmp / 2;
}

Expand Down
197 changes: 152 additions & 45 deletions tests/src/medians.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,75 +10,182 @@
#include "tatami_stats/medians.hpp"
#include "tatami_test/tatami_test.hpp"

TEST(ComputeMedians, Dense) {
{
std::vector<int> vec { 2, 1, 4, 5, 3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, false), 3);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data() + 1, vsize - 1, false), 3.5);
class ComputeMediansTest : public ::testing::Test {
protected:
template<typename Value_>
static double direct_medians(const Value_* vec, size_t n, bool skip_nan) {
std::vector<Value_> copy(vec, vec + n);
return tatami_stats::medians::direct<double>(copy.data(), copy.size(), skip_nan);
}

// Now with NaN stripping.
{
std::vector<double> vec { 2, 1, std::numeric_limits<double>::quiet_NaN(), 5, 3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, true), 2.5);
template<typename Value_, typename Index_>
static double direct_medians(const Value_* vec, Index_ num_nonzero, Index_ num_all, bool skip_nan) {
std::vector<Value_> copy(vec, vec + num_nonzero);
return tatami_stats::medians::direct<double>(copy.data(), num_nonzero, num_all, skip_nan);
}
};

TEST_F(ComputeMediansTest, DenseBasic) {
std::vector<int> vec { 2, 1, 4, 5, 3 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, false), 3);
EXPECT_EQ(direct_medians(vec.data() + 1, vsize - 1, false), 3.5);
EXPECT_EQ(direct_medians(vec.data(), vsize - 1, false), 3);

EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast<double*>(NULL), 0, false)));
}

TEST(ComputeMedians, Sparse) {
{
std::vector<int> vec { 2, 1, 4, 5, 3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 5, false), 3);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 10, false), 0.5);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 9, false), 1);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 8, false), 1.5);
}
TEST_F(ComputeMediansTest, DenseTies) {
std::vector<int> vec { 1, 2, 3, 1, 2, 1 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, false), 1.5);
EXPECT_EQ(direct_medians(vec.data() + 1, vsize - 1, false), 2);
}

{
std::vector<int> vec { -2, -1, -4, -5, -3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 5, false), -3);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 10, false), -0.5);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 9, false), -1);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 8, false), -1.5);
TEST_F(ComputeMediansTest, DenseRealistic) {
for (size_t n = 10; n < 100; n += 10) {
std::mt19937_64 rng(n);
std::vector<double> contents;
std::normal_distribution dist;
for (size_t i = 0; i < n; ++i) {
contents.push_back(dist(rng));
}

// Even
{
auto copy = contents;
std::sort(copy.begin(), copy.end());
EXPECT_EQ(direct_medians(contents.data(), contents.size(), false), (copy[copy.size() / 2] + copy[copy.size() / 2 - 1]) / 2);
}

// Odd
{
auto copy = contents;
copy.pop_back();
std::sort(copy.begin(), copy.end());
EXPECT_EQ(direct_medians(contents.data(), contents.size() - 1, false), copy[copy.size() / 2]);
}
}
}

TEST_F(ComputeMediansTest, DenseNaN) {
std::vector<double> vec { 2, 1, std::numeric_limits<double>::quiet_NaN(), 5, 3 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, true), 2.5);

vec[0] = std::numeric_limits<double>::quiet_NaN();
EXPECT_EQ(direct_medians(vec.data(), vsize, true), 3);

vec[4] = std::numeric_limits<double>::quiet_NaN();
EXPECT_EQ(direct_medians(vec.data(), vsize, true), 3);

// Various mixed flavors.
std::fill(vec.begin(), vec.end(), std::numeric_limits<double>::quiet_NaN());
EXPECT_TRUE(std::isnan(direct_medians(vec.data(), vsize, true)));
}

TEST_F(ComputeMediansTest, SparseAllPositive) {
std::vector<int> vec { 2, 1, 4, 5, 3 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), 3);
EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0.5);
EXPECT_EQ(direct_medians(vec.data(), vsize, 9, false), 1);
EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), 1.5);

EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast<double*>(NULL), 0, 0, false)));
}

TEST_F(ComputeMediansTest, SparseAllNegative) {
std::vector<int> vec { -2, -1, -4, -5, -3 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), -3);
EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), -0.5);
EXPECT_EQ(direct_medians(vec.data(), vsize, 9, false), -1);
EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), -1.5);
}

TEST_F(ComputeMediansTest, SparseMixed) {
// Mostly positive.
{
std::vector<double> vec { 2.5, -1, 4, -5, 3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 5, false), 2.5);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 10, false), 0);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 6, false), 1.25);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 7, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), 2.5);
EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 6, false), 1.25);
EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0);
}

// Mostly negative.
{
std::vector<double> vec { -2.5, 1, -4, 5, -3 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 5, false), -2.5);
EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 10, false), 0);
EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 6, false), -1.25);
EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 7, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), -2.5);
EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 6, false), -1.25);
EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0);
}

// Plus missing values.
// Equal numbers of positive and negative.
{
std::vector<double> vec { 2, 1, std::numeric_limits<double>::quiet_NaN(), 5, 3 };
std::vector<double> vec { -2.5, 1, -4, 5, -3, 6 };
int vsize = vec.size();
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 8, true), 1);
EXPECT_EQ(tatami_stats::medians::direct<double>(vec.data(), vsize, 9, true), 0.5);
EXPECT_FLOAT_EQ(direct_medians(vec.data(), vsize, 6, false), -0.75);
EXPECT_EQ(direct_medians(vec.data(), vsize, 13, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 12, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0);
EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), 0);
}
}

EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast<double*>(NULL), 0, 0, false)));
TEST_F(ComputeMediansTest, SparseNaN) {
std::vector<double> vec { 2, 1, std::numeric_limits<double>::quiet_NaN(), 5, 3 };
int vsize = vec.size();
EXPECT_EQ(direct_medians(vec.data(), vsize, 8, true), 1);
EXPECT_EQ(direct_medians(vec.data(), vsize, 9, true), 0.5);
}

TEST_F(ComputeMediansTest, SparseRealistic) {
for (int n = 10; n < 100; n += 5) {
std::mt19937_64 rng(n);
std::vector<double> contents;
std::normal_distribution dist;
for (int i = 0; i < n; ++i) {
contents.push_back(dist(rng));
}

{
auto ref = direct_medians(contents.data(), n, n, false);
EXPECT_EQ(ref, direct_medians(contents.data(), n, false));
}

// Replacing the back with a zero.
{
auto ref = direct_medians(contents.data(), n - 1, n, false);
auto copy = contents;
copy.back() = 0;
EXPECT_EQ(ref, direct_medians(copy.data(), n, false));
}

// Adding an extra zero.
{
auto ref = direct_medians(contents.data(), n, n + 1, false);
auto copy = contents;
copy.push_back(0);
EXPECT_EQ(ref, direct_medians(copy.data(), n + 1, false));
}

// Adding two extra zeros.
{
auto ref = direct_medians(contents.data(), n, n + 2, false);
auto copy = contents;
copy.push_back(0);
copy.push_back(0);
EXPECT_EQ(ref, direct_medians(copy.data(), n + 2, false));
}
}
}

TEST(ComputingDimMedians, SparseMedians) {
Expand Down

0 comments on commit 87ed962

Please sign in to comment.