diff --git a/include/tatami_stats/medians.hpp b/include/tatami_stats/medians.hpp index eec7baf..fb51ed3 100644 --- a/include/tatami_stats/medians.hpp +++ b/include/tatami_stats/medians.hpp @@ -96,15 +96,19 @@ Output_ direct(Value_* ptr, Index_ num, bool skip_nan) { size_t halfway = num / 2; bool is_even = (num % 2 == 0); - // At some point, I found two nth_element calls to be faster than partial_sort. std::nth_element(ptr, ptr + halfway, ptr + num); - double medtmp = *(ptr + halfway); - if (is_even) { - std::nth_element(ptr, ptr + halfway - 1, ptr + num); - return (medtmp + *(ptr + halfway - 1))/2; - } else { + Output_ medtmp = *(ptr + halfway); + if (!is_even) { return medtmp; } + + // 'nth_element()' reorganizes 'ptr' so that everything below 'halfway' is + // less than or equal to 'ptr[halfway]', while everything above 'halfway' + // is greater than or equal to 'ptr[halfway]'. Thus, to get the element + // immediately before 'halfway' in the sort order, we just need to find the + // maximum from '[0, halfway)'. + Output_ other = *std::max_element(ptr, ptr + halfway); + return (medtmp + other)/2; } /** @@ -126,6 +130,9 @@ Output_ direct(Value_* ptr, Index_ num, bool skip_nan) { */ template Output_ direct(Value_* value, Index_ num_nonzero, Index_ num_all, bool skip_nan) { + // Fallback to the dense code if there are no structural zeros. This is not + // just for efficiency as the downstream averaging code assumes that there + // is at least one structural zero when considering its scenarios. if (num_nonzero == num_all) { return direct(value, num_all, skip_nan); } @@ -151,35 +158,51 @@ Output_ direct(Value_* value, Index_ num_nonzero, Index_ num_all, bool skip_nan) size_t halfway = num_all / 2; bool is_even = (num_all % 2 == 0); - auto vend = value + num_nonzero; - std::sort(value, vend); - size_t zeropos = std::lower_bound(value, vend, 0) - value; - size_t nzero = num_all - num_nonzero; + size_t num_zero = num_all - num_nonzero; + size_t num_negative = 0; + for (Index_ i = 0; i < num_nonzero; ++i) { + num_negative += (value[i] < 0); + } if (!is_even) { - if (zeropos > halfway) { + if (num_negative > halfway) { + std::nth_element(value, value + halfway, value + num_nonzero); return value[halfway]; - } else if (halfway >= zeropos + nzero) { - return value[halfway - nzero]; + + } else if (halfway >= num_negative + num_zero) { + size_t skip_zeros = halfway - num_zero; + std::nth_element(value, value + skip_zeros, value + num_nonzero); + return value[skip_zeros]; + } else { - return 0; // zero is the median. + return 0; } } - double tmp = 0; - if (zeropos > halfway) { - tmp = value[halfway] + value[halfway - 1]; - } else if (zeropos == halfway) { - // guaranteed to be at least 1 zero. - tmp += value[halfway - 1]; - } else if (zeropos < halfway && zeropos + nzero > halfway) { - ; // zero is the median. - } else if (zeropos + nzero == halfway) { - // guaranteed to be at least 1 zero. - tmp += value[halfway - nzero]; - } else { - tmp = value[halfway - nzero] + value[halfway - nzero - 1]; + Output_ tmp = 0; + if (num_negative > halfway) { // both halves of the median are negative. + std::nth_element(value, value + halfway, value + num_nonzero); + tmp = value[halfway] + *(std::max_element(value, value + halfway)); // max_element gets the sorted value at halfway - 1, see explanation for the dense case. + + } else if (num_negative == halfway) { // the upper half is guaranteed to be zero. + size_t below_halfway = halfway - 1; + std::nth_element(value, value + below_halfway, value + num_nonzero); + tmp = value[below_halfway]; + + } else if (num_negative < halfway && num_negative + num_zero > halfway) { // both halves are zero, so zero is the median. + ; + + } else if (num_negative + num_zero == halfway) { // the lower half is guaranteed to be zero. + size_t skip_zeros = halfway - num_zero; + std::nth_element(value, value + skip_zeros, value + num_nonzero); + tmp = value[skip_zeros]; + + } else { // both halves of the median are non-negative. + size_t skip_zeros = halfway - num_zero; + std::nth_element(value, value + skip_zeros, value + num_nonzero); + tmp = value[skip_zeros] + *(std::max_element(value, value + skip_zeros)); // max_element gets the sorted value at skip_zeros - 1, see explanation for the dense case. } + return tmp / 2; } diff --git a/tests/src/medians.cpp b/tests/src/medians.cpp index ba67558..7968b59 100644 --- a/tests/src/medians.cpp +++ b/tests/src/medians.cpp @@ -10,75 +10,182 @@ #include "tatami_stats/medians.hpp" #include "tatami_test/tatami_test.hpp" -TEST(ComputeMedians, Dense) { - { - std::vector vec { 2, 1, 4, 5, 3 }; - int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, false), 3); - EXPECT_EQ(tatami_stats::medians::direct(vec.data() + 1, vsize - 1, false), 3.5); +class ComputeMediansTest : public ::testing::Test { +protected: + template + static double direct_medians(const Value_* vec, size_t n, bool skip_nan) { + std::vector copy(vec, vec + n); + return tatami_stats::medians::direct(copy.data(), copy.size(), skip_nan); } - // Now with NaN stripping. - { - std::vector vec { 2, 1, std::numeric_limits::quiet_NaN(), 5, 3 }; - int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, true), 2.5); + template + static double direct_medians(const Value_* vec, Index_ num_nonzero, Index_ num_all, bool skip_nan) { + std::vector copy(vec, vec + num_nonzero); + return tatami_stats::medians::direct(copy.data(), num_nonzero, num_all, skip_nan); } +}; + +TEST_F(ComputeMediansTest, DenseBasic) { + std::vector vec { 2, 1, 4, 5, 3 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, false), 3); + EXPECT_EQ(direct_medians(vec.data() + 1, vsize - 1, false), 3.5); + EXPECT_EQ(direct_medians(vec.data(), vsize - 1, false), 3); EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast(NULL), 0, false))); } -TEST(ComputeMedians, Sparse) { - { - std::vector vec { 2, 1, 4, 5, 3 }; - int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 5, false), 3); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 11, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 10, false), 0.5); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 9, false), 1); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 8, false), 1.5); - } +TEST_F(ComputeMediansTest, DenseTies) { + std::vector vec { 1, 2, 3, 1, 2, 1 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, false), 1.5); + EXPECT_EQ(direct_medians(vec.data() + 1, vsize - 1, false), 2); +} - { - std::vector vec { -2, -1, -4, -5, -3 }; - int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 5, false), -3); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 11, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 10, false), -0.5); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 9, false), -1); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 8, false), -1.5); +TEST_F(ComputeMediansTest, DenseRealistic) { + for (size_t n = 10; n < 100; n += 10) { + std::mt19937_64 rng(n); + std::vector contents; + std::normal_distribution dist; + for (size_t i = 0; i < n; ++i) { + contents.push_back(dist(rng)); + } + + // Even + { + auto copy = contents; + std::sort(copy.begin(), copy.end()); + EXPECT_EQ(direct_medians(contents.data(), contents.size(), false), (copy[copy.size() / 2] + copy[copy.size() / 2 - 1]) / 2); + } + + // Odd + { + auto copy = contents; + copy.pop_back(); + std::sort(copy.begin(), copy.end()); + EXPECT_EQ(direct_medians(contents.data(), contents.size() - 1, false), copy[copy.size() / 2]); + } } +} + +TEST_F(ComputeMediansTest, DenseNaN) { + std::vector vec { 2, 1, std::numeric_limits::quiet_NaN(), 5, 3 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, true), 2.5); + + vec[0] = std::numeric_limits::quiet_NaN(); + EXPECT_EQ(direct_medians(vec.data(), vsize, true), 3); + + vec[4] = std::numeric_limits::quiet_NaN(); + EXPECT_EQ(direct_medians(vec.data(), vsize, true), 3); - // Various mixed flavors. + std::fill(vec.begin(), vec.end(), std::numeric_limits::quiet_NaN()); + EXPECT_TRUE(std::isnan(direct_medians(vec.data(), vsize, true))); +} + +TEST_F(ComputeMediansTest, SparseAllPositive) { + std::vector vec { 2, 1, 4, 5, 3 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), 3); + EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0.5); + EXPECT_EQ(direct_medians(vec.data(), vsize, 9, false), 1); + EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), 1.5); + + EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast(NULL), 0, 0, false))); +} + +TEST_F(ComputeMediansTest, SparseAllNegative) { + std::vector vec { -2, -1, -4, -5, -3 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), -3); + EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), -0.5); + EXPECT_EQ(direct_medians(vec.data(), vsize, 9, false), -1); + EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), -1.5); +} + +TEST_F(ComputeMediansTest, SparseMixed) { + // Mostly positive. { std::vector vec { 2.5, -1, 4, -5, 3 }; int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 5, false), 2.5); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 11, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 10, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 6, false), 1.25); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 7, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), 2.5); + EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 6, false), 1.25); + EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0); } + // Mostly negative. { std::vector vec { -2.5, 1, -4, 5, -3 }; int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 5, false), -2.5); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 11, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 10, false), 0); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 6, false), -1.25); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 7, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 5, false), -2.5); + EXPECT_EQ(direct_medians(vec.data(), vsize, 11, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 10, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 6, false), -1.25); + EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0); } - // Plus missing values. + // Equal numbers of positive and negative. { - std::vector vec { 2, 1, std::numeric_limits::quiet_NaN(), 5, 3 }; + std::vector vec { -2.5, 1, -4, 5, -3, 6 }; int vsize = vec.size(); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 8, true), 1); - EXPECT_EQ(tatami_stats::medians::direct(vec.data(), vsize, 9, true), 0.5); + EXPECT_FLOAT_EQ(direct_medians(vec.data(), vsize, 6, false), -0.75); + EXPECT_EQ(direct_medians(vec.data(), vsize, 13, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 12, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 7, false), 0); + EXPECT_EQ(direct_medians(vec.data(), vsize, 8, false), 0); } +} - EXPECT_TRUE(std::isnan(tatami_stats::medians::direct(static_cast(NULL), 0, 0, false))); +TEST_F(ComputeMediansTest, SparseNaN) { + std::vector vec { 2, 1, std::numeric_limits::quiet_NaN(), 5, 3 }; + int vsize = vec.size(); + EXPECT_EQ(direct_medians(vec.data(), vsize, 8, true), 1); + EXPECT_EQ(direct_medians(vec.data(), vsize, 9, true), 0.5); +} + +TEST_F(ComputeMediansTest, SparseRealistic) { + for (int n = 10; n < 100; n += 5) { + std::mt19937_64 rng(n); + std::vector contents; + std::normal_distribution dist; + for (int i = 0; i < n; ++i) { + contents.push_back(dist(rng)); + } + + { + auto ref = direct_medians(contents.data(), n, n, false); + EXPECT_EQ(ref, direct_medians(contents.data(), n, false)); + } + + // Replacing the back with a zero. + { + auto ref = direct_medians(contents.data(), n - 1, n, false); + auto copy = contents; + copy.back() = 0; + EXPECT_EQ(ref, direct_medians(copy.data(), n, false)); + } + + // Adding an extra zero. + { + auto ref = direct_medians(contents.data(), n, n + 1, false); + auto copy = contents; + copy.push_back(0); + EXPECT_EQ(ref, direct_medians(copy.data(), n + 1, false)); + } + + // Adding two extra zeros. + { + auto ref = direct_medians(contents.data(), n, n + 2, false); + auto copy = contents; + copy.push_back(0); + copy.push_back(0); + EXPECT_EQ(ref, direct_medians(copy.data(), n + 2, false)); + } + } } TEST(ComputingDimMedians, SparseMedians) {