Skip to content

Commit

Permalink
Fix MultiRange filter to correctly handle NaN input
Browse files Browse the repository at this point in the history
Summary:
The current MultiRange filter includes an extra flag specifically for
allowing or filtering out NaN values, independent of its internal
filters. This responsibility should instead be managed by the filters
themselves. This change therefore removes that additional handling to
ensure consistency between using a single filter VS using multiple.

Differential Revision: D60138398
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Jul 23, 2024
1 parent dd7b290 commit 6da7ee6
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 69 deletions.
2 changes: 1 addition & 1 deletion velox/connectors/hive/HiveConnectorUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ std::unique_ptr<common::Filter> makeFloatingPointMapKeyFilter(
if (filters.size() == 1) {
return std::move(filters[0]);
}
return std::make_unique<common::MultiRange>(std::move(filters), false, false);
return std::make_unique<common::MultiRange>(std::move(filters), false);
}

// Recursively add subfields to scan spec.
Expand Down
3 changes: 1 addition & 2 deletions velox/expression/ExprToSubfieldFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ std::unique_ptr<common::Filter> makeNotEqualFilter(
std::vector<std::unique_ptr<common::Filter>> filters;
filters.emplace_back(std::move(lessThanFilter));
filters.emplace_back(std::move(greaterThanFilter));
return std::make_unique<common::MultiRange>(
std::move(filters), false, false);
return std::make_unique<common::MultiRange>(std::move(filters), false);
}
}

Expand Down
10 changes: 3 additions & 7 deletions velox/expression/ExprToSubfieldFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,16 +320,12 @@ inline std::unique_ptr<common::IsNotNull> isNotNull() {
}

template <typename T>
std::unique_ptr<common::MultiRange> orFilter(
std::unique_ptr<T> a,
std::unique_ptr<T> b,
bool nullAllowed = false,
bool nanAllowed = false) {
std::unique_ptr<common::MultiRange>
orFilter(std::unique_ptr<T> a, std::unique_ptr<T> b, bool nullAllowed = false) {
std::vector<std::unique_ptr<common::Filter>> filters;
filters.emplace_back(std::move(a));
filters.emplace_back(std::move(b));
return std::make_unique<common::MultiRange>(
std::move(filters), nullAllowed, nanAllowed);
return std::make_unique<common::MultiRange>(std::move(filters), nullAllowed);
}

inline std::unique_ptr<common::HugeintRange> lessThanHugeint(
Expand Down
28 changes: 7 additions & 21 deletions velox/type/Filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,6 @@ bool NegatedBytesValues::testingEquals(const Filter& other) const {

folly::dynamic MultiRange::serialize() const {
auto obj = Filter::serializeBase("MultiRange");
obj["nanAllowed"] = nanAllowed_;
folly::dynamic arr = folly::dynamic::array;
for (const auto& f : filters_) {
arr.push_back(f->serialize());
Expand All @@ -723,7 +722,6 @@ folly::dynamic MultiRange::serialize() const {

FilterPtr MultiRange::create(const folly::dynamic& obj) {
auto nullAllowed = deserializeNullAllowed(obj);
auto nanAllowed = obj["nanAllowed"].asBool();
folly::dynamic arr = obj["filters"];
auto tmpFilters = ISerializable::deserialize<std::vector<Filter>>(arr);

Expand All @@ -733,14 +731,13 @@ FilterPtr MultiRange::create(const folly::dynamic& obj) {
filters.emplace_back(f->clone());
}

return std::make_unique<MultiRange>(
std::move(filters), nullAllowed, nanAllowed);
return std::make_unique<MultiRange>(std::move(filters), nullAllowed);
}

bool MultiRange::testingEquals(const Filter& other) const {
auto otherMultiRange = dynamic_cast<const MultiRange*>(&other);
auto res = otherMultiRange != nullptr && Filter::testingBaseEquals(other) &&
nanAllowed_ == otherMultiRange->nanAllowed_ &&

filters_.size() == otherMultiRange->filters_.size();

if (!res) {
Expand Down Expand Up @@ -1357,7 +1354,7 @@ std::unique_ptr<Filter> NegatedBytesRange::toMultiRange() const {
if (accepted.size() == 1) {
return accepted[0]->clone(nullAllowed_);
}
return std::make_unique<MultiRange>(std::move(accepted), nullAllowed_, false);
return std::make_unique<MultiRange>(std::move(accepted), nullAllowed_);
}

bool NegatedBytesValues::testBytesRange(
Expand Down Expand Up @@ -1439,17 +1436,13 @@ std::unique_ptr<Filter> MultiRange::clone(

if (nullAllowed) {
return std::make_unique<MultiRange>(
std::move(filters), nullAllowed.value(), nanAllowed_);
std::move(filters), nullAllowed.value());
} else {
return std::make_unique<MultiRange>(
std::move(filters), nullAllowed_, nanAllowed_);
return std::make_unique<MultiRange>(std::move(filters), nullAllowed_);
}
}

bool MultiRange::testDouble(double value) const {
if (std::isnan(value)) {
return nanAllowed_;
}
for (const auto& filter : filters_) {
if (filter->testDouble(value)) {
return true;
Expand All @@ -1459,9 +1452,6 @@ bool MultiRange::testDouble(double value) const {
}

bool MultiRange::testFloat(float value) const {
if (std::isnan(value)) {
return nanAllowed_;
}
for (const auto& filter : filters_) {
if (filter->testFloat(value)) {
return true;
Expand Down Expand Up @@ -1554,15 +1544,13 @@ std::unique_ptr<Filter> MultiRange::mergeWith(const Filter* other) const {
case FilterKind::kBytesRange:
case FilterKind::kMultiRange: {
bool bothNullAllowed = nullAllowed_ && other->testNull();
bool bothNanAllowed = nanAllowed_;
std::vector<const Filter*> otherFilters;

if (other->kind() == FilterKind::kMultiRange) {
auto multiRangeOther = static_cast<const MultiRange*>(other);
for (auto const& filterOther : multiRangeOther->filters()) {
otherFilters.emplace_back(filterOther.get());
}
bothNanAllowed = bothNanAllowed && multiRangeOther->nanAllowed();
} else {
otherFilters.emplace_back(other);
}
Expand Down Expand Up @@ -1614,8 +1602,7 @@ std::unique_ptr<Filter> MultiRange::mergeWith(const Filter* other) const {
} else if (merged.size() == 1) {
return merged.front()->clone(bothNullAllowed);
} else {
return std::make_unique<MultiRange>(
std::move(merged), bothNullAllowed, bothNanAllowed);
return std::make_unique<MultiRange>(std::move(merged), bothNullAllowed);
}
}
default:
Expand Down Expand Up @@ -2637,8 +2624,7 @@ std::unique_ptr<Filter> NegatedBytesValues::mergeWith(
bytesRangeOther->upperUnbounded(),
hiExclusive,
false));
return std::make_unique<MultiRange>(
std::move(ranges), bothNullAllowed, false);
return std::make_unique<MultiRange>(std::move(ranges), bothNullAllowed);
}
default:
VELOX_UNREACHABLE();
Expand Down
16 changes: 2 additions & 14 deletions velox/type/Filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2042,16 +2042,9 @@ class MultiRange final : public Filter {
/// All entries must support the same data types.
/// @param nullAllowed Null values are passing the filter if true. nullAllowed
/// flags in the 'ranges' filters are ignored.
/// @param nanAllowed Not-a-Number floating point values are passing the
/// filter if true. Applies to floating point data types only. NaN values are
/// not further tested using contained filters.
MultiRange(
std::vector<std::unique_ptr<Filter>> filters,
bool nullAllowed,
bool nanAllowed)
MultiRange(std::vector<std::unique_ptr<Filter>> filters, bool nullAllowed)
: Filter(true, nullAllowed, FilterKind::kMultiRange),
filters_(std::move(filters)),
nanAllowed_(nanAllowed) {}
filters_(std::move(filters)) {}

folly::dynamic serialize() const override;

Expand Down Expand Up @@ -2083,15 +2076,10 @@ class MultiRange final : public Filter {

std::unique_ptr<Filter> mergeWith(const Filter* other) const override final;

bool nanAllowed() const {
return nanAllowed_;
}

bool testingEquals(const Filter& other) const final;

private:
const std::vector<std::unique_ptr<Filter>> filters_;
const bool nanAllowed_;
};

// Helper for applying filters to different types
Expand Down
2 changes: 1 addition & 1 deletion velox/type/tests/FilterSerDeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ TEST_F(FilterSerDeTest, multiFilter) {
filters.emplace_back(std::make_unique<BytesRange>(
"ABCD", true, true, "FFFF", false, true, false));

MultiRange multiRange(std::move(filters), true, true);
MultiRange multiRange(std::move(filters), true);
testSerde(multiRange);
}

Expand Down
28 changes: 8 additions & 20 deletions velox/type/tests/FilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ TEST(FilterTest, multiRange) {
EXPECT_TRUE(filter->testDouble(1.3));

EXPECT_FALSE(filter->testNull());
EXPECT_FALSE(filter->testDouble(std::nan("nan")));
EXPECT_TRUE(filter->testDouble(std::nan("nan")));
EXPECT_FALSE(filter->testDouble(1.2));

filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.2));
Expand All @@ -1059,7 +1059,7 @@ TEST(FilterTest, multiRange) {
EXPECT_TRUE(filter->testFloat(1.1f));
EXPECT_FALSE(filter->testFloat(1.2f));
EXPECT_TRUE(filter->testFloat(1.3f));
EXPECT_FALSE(filter->testFloat(std::nanf("nan")));
EXPECT_TRUE(filter->testFloat(std::nanf("nan")));

// != ''
filter = orFilter(lessThan(""), greaterThan(""));
Expand All @@ -1069,51 +1069,39 @@ TEST(FilterTest, multiRange) {

TEST(FilterTest, multiRangeWithNaNs) {
// x <> 1.2 with nanAllowed true
auto filter =
orFilter(lessThanFloat(1.2), greaterThanFloat(1.2), false, true);
auto filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.2), false);
EXPECT_TRUE(filter->testFloat(std::nanf("nan")));
EXPECT_FALSE(filter->testFloat(1.2f));
EXPECT_TRUE(filter->testFloat(1.1f));

filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.2), false, true);
filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.2), false);
EXPECT_TRUE(filter->testDouble(std::nan("nan")));
EXPECT_FALSE(filter->testDouble(1.2));
EXPECT_TRUE(filter->testDouble(1.1));

// x <> 1.2 with nanAllowed false
filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.2));
EXPECT_FALSE(filter->testFloat(std::nanf("nan")));
EXPECT_TRUE(filter->testFloat(std::nanf("nan")));
EXPECT_TRUE(filter->testFloat(1.0f));

filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.2));
EXPECT_FALSE(filter->testDouble(std::nan("nan")));
EXPECT_TRUE(filter->testDouble(std::nan("nan")));
EXPECT_TRUE(filter->testDouble(1.4));

// x NOT IN (1.2, 1.3) with nanAllowed true
filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.3), false, true);
filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.3), false);
EXPECT_TRUE(filter->testFloat(std::nanf("nan")));
EXPECT_FALSE(filter->testFloat(1.2f));
EXPECT_FALSE(filter->testFloat(1.3f));
EXPECT_TRUE(filter->testFloat(1.4f));
EXPECT_TRUE(filter->testFloat(1.1f));

filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.3), false, true);
filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.3), false);
EXPECT_TRUE(filter->testDouble(std::nan("nan")));
EXPECT_FALSE(filter->testDouble(1.2));
EXPECT_FALSE(filter->testDouble(1.3));
EXPECT_TRUE(filter->testDouble(1.4));
EXPECT_TRUE(filter->testDouble(1.1));

// x NOT IN (1.2) with nanAllowed false
filter = orFilter(lessThanFloat(1.2), greaterThanFloat(1.2));
EXPECT_FALSE(filter->testFloat(std::nanf("nan")));
EXPECT_FALSE(filter->testFloat(1.2f));
EXPECT_TRUE(filter->testFloat(1.3f));

filter = orFilter(lessThanDouble(1.2), greaterThanDouble(1.2));
EXPECT_FALSE(filter->testDouble(std::nan("nan")));
EXPECT_FALSE(filter->testDouble(1.2));
EXPECT_TRUE(filter->testDouble(1.3));
}

TEST(FilterTest, createBigintValues) {
Expand Down
4 changes: 2 additions & 2 deletions velox/type/tests/NegatedBytesRangeBenchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ int32_t main(int32_t argc, char** argv) {
"", true, false, lo, false, true, false));
rangeFilters.emplace_back(std::make_unique<common::BytesRange>(
hi, false, false, "", true, false, false));
multiRanges.emplace_back(std::make_unique<common::MultiRange>(
std::move(rangeFilters), false, false));
multiRanges.emplace_back(
std::make_unique<common::MultiRange>(std::move(rangeFilters), false));

LOG(INFO) << "Generated filter for length " << len << " with percentage "
<< pct;
Expand Down
2 changes: 1 addition & 1 deletion velox/type/tests/NegatedBytesValuesBenchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ int32_t main(int32_t argc, char* argv[]) {
range_filters.emplace_back(std::make_unique<common::BytesRange>(
*back, false, true, "", true, true, false));
multi_ranges.emplace_back(std::make_unique<common::MultiRange>(
std::move(range_filters), false, false));
std::move(range_filters), false));

LOG(INFO) << "Generated filter for length " << len << " with size "
<< size;
Expand Down

0 comments on commit 6da7ee6

Please sign in to comment.