Skip to content

Commit

Permalink
Support custom comparison in Presto's IN function (#11032)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11032

Building on #11021 this adds support for custom
comparison functions provided by custom types in Presto's IN function.

I was able to reuse the ComplexTypeInPredicate and the support for custom comparisons already
present in BaseVector.

This diff is largely just renaming ComplexTypeInPredicate to VectorSetInPredicate (to clarify it's not
just for complex types anymore) and if statement to identify the case where
providesCustomComparison() is true for the element type (and of course updating the tests).

Making TimestampWithTimeZone a special case of bigint (comparing the millis) in the future might
give a performance boost if this shows up as a bottleneck.

Reviewed By: xiaoxmeng

Differential Revision: D62994557

fbshipit-source-id: 82d3eb2c3d24118f7f555b3f739810c91c58fc32
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Sep 26, 2024
1 parent 71d0697 commit 7a9b141
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 57 deletions.
38 changes: 23 additions & 15 deletions velox/functions/prestosql/InPredicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,39 @@
namespace facebook::velox::functions {
namespace {

// This implements InPredicate using a set over VectorValues (pairs of
// BaseVector, index). Can be used in place of Filters for Types not supported
// by Filters or when custom comparisons are needed.
// Returns NULL if
// - input value is NULL
// - in-list is NULL or empty
// - input value doesn't have an exact match, but has an indeterminate match in
// the in-list. E.g., 'array[null] in (array[1])' or 'array[1] in
// (array[null])'.
class ComplexTypeInPredicate : public exec::VectorFunction {
class VectorSetInPredicate : public exec::VectorFunction {
public:
struct ComplexValue {
struct VectorValue {
BaseVector* vector;
vector_size_t index;
};

struct ComplexValueHash {
size_t operator()(ComplexValue value) const {
struct VectorValueHash {
size_t operator()(VectorValue value) const {
return value.vector->hashValueAt(value.index);
}
};

struct ComplexValueEqualTo {
bool operator()(ComplexValue left, ComplexValue right) const {
struct VectorValueEqualTo {
bool operator()(VectorValue left, VectorValue right) const {
return left.vector->equalValueAt(right.vector, left.index, right.index);
}
};

using ComplexSet =
folly::F14FastSet<ComplexValue, ComplexValueHash, ComplexValueEqualTo>;
using VectorSet =
folly::F14FastSet<VectorValue, VectorValueHash, VectorValueEqualTo>;

ComplexTypeInPredicate(
ComplexSet uniqueValues,
VectorSetInPredicate(
VectorSet uniqueValues,
bool hasNull,
VectorPtr originalValues)
: uniqueValues_{std::move(uniqueValues)},
Expand All @@ -58,7 +61,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction {

static std::shared_ptr<exec::VectorFunction>
create(const VectorPtr& values, vector_size_t offset, vector_size_t size) {
ComplexSet uniqueValues;
VectorSet uniqueValues;
bool hasNull = false;

for (auto i = offset; i < offset + size; i++) {
Expand All @@ -68,7 +71,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction {
uniqueValues.insert({values.get(), i});
}

return std::make_shared<ComplexTypeInPredicate>(
return std::make_shared<VectorSetInPredicate>(
std::move(uniqueValues), hasNull, values);
}

Expand Down Expand Up @@ -126,7 +129,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction {

// Set of unique values to check against. This set doesn't include any value
// that is null or contains null.
const ComplexSet uniqueValues_;
const VectorSet uniqueValues_;

// Boolean indicating whether one of the value was null or contained null.
const bool hasNull_;
Expand Down Expand Up @@ -339,10 +342,15 @@ class InPredicate : public exec::VectorFunction {
}

const auto& elements = arrayVector->elements();
const auto& elementType = elements->type();

if (elementType->providesCustomComparison()) {
return VectorSetInPredicate::create(elements, offset, size);
}

std::pair<std::unique_ptr<common::Filter>, bool> filter;

switch (inListType->childAt(0)->kind()) {
switch (elementType->kind()) {
case TypeKind::HUGEINT:
filter = createHugeintValuesFilter<int128_t>(elements, offset, size);
break;
Expand Down Expand Up @@ -384,7 +392,7 @@ class InPredicate : public exec::VectorFunction {
case TypeKind::MAP:
[[fallthrough]];
case TypeKind::ROW:
return ComplexTypeInPredicate::create(elements, offset, size);
return VectorSetInPredicate::create(elements, offset, size);
default:
VELOX_UNSUPPORTED(
"Unsupported in-list type for IN predicate: {}",
Expand Down
103 changes: 61 additions & 42 deletions velox/functions/prestosql/tests/InPredicateTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/lib/DateTimeFormatter.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/tz/TimeZoneMap.h"

using namespace facebook::velox::test;
using namespace facebook::velox::functions::test;
Expand All @@ -25,50 +28,50 @@ namespace {
class InPredicateTest : public FunctionBaseTest {
protected:
template <typename T>
std::string getInList(
ArrayVectorPtr getInList(
std::vector<std::optional<T>> input,
const TypePtr& type = CppToType<T>::create()) {
const TypePtr& type) {
FlatVectorPtr<T> flatVec = makeNullableFlatVector<T>(input, type);
std::string inList;
auto len = flatVec->size();
auto toString = [&](vector_size_t idx) {
if (type->isDecimal()) {
if (flatVec->isNullAt(idx)) {
return std::string("null");
}
return fmt::format(
"cast({} as {})", flatVec->toString(idx), type->toString());
}
return flatVec->toString(idx);
};

for (auto i = 0; i < len - 1; i++) {
inList += fmt::format("{}, ", toString(i));
}
inList += toString(len - 1);
return inList;
return makeArrayVector({0, flatVec->size()}, flatVec);
}

core::TypedExprPtr makeInExpression(
const std::string& probe,
const ArrayVectorPtr& inList,
const TypePtr& type) {
return std::make_shared<core::CallTypedExpr>(
BOOLEAN(),
std::vector<core::TypedExprPtr>{
std::make_shared<core::FieldAccessTypedExpr>(type, probe),
std::make_shared<core::ConstantTypedExpr>(inList)},
"in");
}

template <typename T>
void testValues(const TypePtr type = CppToType<T>::create()) {
void testValues(
const TypePtr type = CppToType<T>::create(),
std::function<T(vector_size_t /*row*/)> valueAt = [](auto row) {
return row % 17;
}) {
if (type->isDecimal()) {
this->options_.parseDecimalAsDouble = false;
}
std::shared_ptr<memory::MemoryPool> pool{
memory::memoryManager()->addLeafPool()};

const vector_size_t size = 1'000;
auto inList = getInList<T>({1, 3, 5}, type);
auto inList = getInList<T>({valueAt(1), valueAt(3), valueAt(5)}, type);

auto vector = makeFlatVector<T>(
size, [](auto row) { return row % 17; }, nullptr, type);
size, [&](auto row) { return valueAt(row); }, nullptr, type);
auto vectorWithNulls = makeFlatVector<T>(
size, [](auto row) { return row % 17; }, nullEvery(7), type);
size, [&](auto row) { return valueAt(row); }, nullEvery(7), type);
auto rowVector = makeRowVector({vector, vectorWithNulls});

// no nulls
auto result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
auto expected = makeFlatVector<bool>(size, [](auto row) {
auto n = row % 17;
return n == 1 || n == 3 || n == 5;
Expand All @@ -78,7 +81,7 @@ class InPredicateTest : public FunctionBaseTest {

// some nulls
result = evaluate<SimpleVector<bool>>(
fmt::format("c1 IN ({})", inList), rowVector);
makeInExpression("c1", inList, type), rowVector);
expected = makeFlatVector<bool>(
size,
[](auto row) {
Expand All @@ -91,9 +94,10 @@ class InPredicateTest : public FunctionBaseTest {

// null values in the in-list
// The results can be either true or null, but not false.
inList = getInList<T>({1, 3, std::nullopt, 5}, type);
inList =
getInList<T>({valueAt(1), valueAt(3), std::nullopt, valueAt(5)}, type);
result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
expected = makeFlatVector<bool>(
size,
[](auto /* row */) { return true; },
Expand All @@ -105,7 +109,7 @@ class InPredicateTest : public FunctionBaseTest {
assertEqualVectors(expected, result);

result = evaluate<SimpleVector<bool>>(
fmt::format("c1 IN ({})", inList), rowVector);
makeInExpression("c1", inList, type), rowVector);
expected = makeFlatVector<bool>(
size,
[](auto /* row */) { return true; },
Expand All @@ -116,9 +120,9 @@ class InPredicateTest : public FunctionBaseTest {

assertEqualVectors(expected, result);

inList = getInList<T>({2, std::nullopt}, type);
inList = getInList<T>({valueAt(2), std::nullopt}, type);
result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
expected = makeFlatVector<bool>(
size,
[](auto /* row */) { return true; },
Expand All @@ -130,7 +134,7 @@ class InPredicateTest : public FunctionBaseTest {
assertEqualVectors(expected, result);

result = evaluate<SimpleVector<bool>>(
fmt::format("c1 IN ({})", inList), rowVector);
makeInExpression("c1", inList, type), rowVector);
expected = makeFlatVector<bool>(
size,
[](auto /* row */) { return true; },
Expand Down Expand Up @@ -173,9 +177,9 @@ class InPredicateTest : public FunctionBaseTest {

rowVector = makeRowVector({dict});

inList = getInList<T>({2, 5, 9}, type);
inList = getInList<T>({valueAt(2), valueAt(5), valueAt(9)}, type);
result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
assertEqualVectors(expected, result);

// an in list with nulls only is always null.
Expand All @@ -186,37 +190,41 @@ class InPredicateTest : public FunctionBaseTest {
}

template <typename T>
void testConstantValues(const TypePtr type = CppToType<T>::create()) {
void testConstantValues(
const TypePtr type = CppToType<T>::create(),
std::function<T(vector_size_t /*row*/)> valueAt = [](auto row) {
return row % 17;
}) {
const vector_size_t size = 1'000;
auto rowVector = makeRowVector(
{makeConstant(static_cast<T>(123), size, type),
{makeConstant(valueAt(123), size, type),
BaseVector::createNullConstant(type, size, pool())});
auto inList = getInList<T>({1, 3, 5}, type);
auto inList = getInList<T>({valueAt(1), valueAt(3), valueAt(5)}, type);

auto constTrue = makeConstant(true, size);
auto constFalse = makeConstant(false, size);
auto constNull = makeNullConstant(TypeKind::BOOLEAN, size);

// a miss
auto result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
assertEqualVectors(constFalse, result);

// null
result = evaluate<SimpleVector<bool>>(
fmt::format("c1 IN ({})", inList), rowVector);
makeInExpression("c1", inList, type), rowVector);
assertEqualVectors(constNull, result);

// a hit
inList = getInList<T>({1, 123, 5}, type);
inList = getInList<T>({valueAt(1), valueAt(123), valueAt(5)}, type);
result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c0", inList, type), rowVector);
assertEqualVectors(constTrue, result);

// a miss that is a null
inList = getInList<T>({1, std::nullopt, 5}, type);
inList = getInList<T>({valueAt(1), std::nullopt, valueAt(5)}, type);
result = evaluate<SimpleVector<bool>>(
fmt::format("c0 IN ({})", inList), rowVector);
makeInExpression("c1", inList, type), rowVector);
assertEqualVectors(constNull, result);
}

Expand Down Expand Up @@ -1120,5 +1128,16 @@ TEST_F(InPredicateTest, nans) {
testNaNs<double>();
}

TEST_F(InPredicateTest, TimestampWithTimeZone) {
// The millis ranges from 0-17, but after every 17th row we increment the time
// zone ID, so that no two rows have the same millis and time zone. However,
// by the semantics of TimestampWithTimeZone's comparison, it's the same 17
// values repeated.
auto valueAt = [](auto row) { return pack(row % 17, row / 17); };

testValues<int64_t>(TIMESTAMP_WITH_TIME_ZONE(), valueAt);
testConstantValues<int64_t>(TIMESTAMP_WITH_TIME_ZONE(), valueAt);
}

} // namespace
} // namespace facebook::velox::functions

0 comments on commit 7a9b141

Please sign in to comment.