Skip to content

Commit

Permalink
Add support for custom comparison in Presto's array_distinct (#11143)
Browse files Browse the repository at this point in the history
Summary:

Update Presto's array_distinct UDF to work with types that provide custom
comparison.  We can reuse the implementation of ValueSet for complex types, since
that just uses the compare and hash functions provided by the Vector.  With
#11022 these just invoke the Type's
custom implementations of these functions.

Reviewed By: xiaoxmeng

Differential Revision: D63719542
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 7, 2024
1 parent efd99ee commit 86e87d7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 17 deletions.
50 changes: 33 additions & 17 deletions velox/functions/prestosql/ArrayDistinct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,35 @@ struct ValueSet {

template <>
struct ValueSet<ComplexType> {
using TKey = std::tuple<uint64_t, const BaseVector*, vector_size_t>;
struct Key {
const uint64_t hash;
const BaseVector* vector;
const vector_size_t index;
};

struct Hash {
size_t operator()(const TKey& key) const {
return std::get<0>(key);
size_t operator()(const Key& key) const {
return key.hash;
}
};

struct EqualTo {
bool operator()(const TKey& left, const TKey& right) const {
return std::get<1>(left)
bool operator()(const Key& left, const Key& right) const {
return left.vector
->equalValueAt(
std::get<1>(right),
std::get<2>(left),
std::get<2>(right),
right.vector,
left.index,
right.index,
CompareFlags::NullHandlingMode::kNullAsValue)
.value();
}
};

folly::F14FastSet<TKey, Hash, EqualTo> values;
folly::F14FastSet<Key, Hash, EqualTo> values;

bool insert(const BaseVector* vector, vector_size_t index) {
const uint64_t hash = vector->hashValueAt(index);
return values.insert(std::make_tuple(hash, vector, index)).second;
return values.insert({hash, vector, index}).second;
}

void reset() {
Expand All @@ -86,7 +90,7 @@ struct ValueSet<ComplexType> {
/// which will be present in the output, and wrapped into a DictionaryVector.
/// Next the `lengths` and `offsets` vectors that control where output arrays
/// start and end are wrapped into the output ArrayVector.
template <typename T>
template <typename T, bool useCustomComparison = false>
class ArrayDistinctFunction : public exec::VectorFunction {
public:
void apply(
Expand Down Expand Up @@ -117,6 +121,12 @@ class ArrayDistinctFunction : public exec::VectorFunction {
}

private:
// We want to use ValueSet<ComplexType> when we need custom comparison because
// it uses the Vector's implementation of compare and hash which it gets from
// the type.
using ValueSetT = std::
conditional_t<useCustomComparison, ValueSet<ComplexType>, ValueSet<T>>;

VectorPtr applyFlat(
const SelectivityVector& rows,
const VectorPtr& arg,
Expand All @@ -143,7 +153,7 @@ class ArrayDistinctFunction : public exec::VectorFunction {
auto* rawNewOffsets = newOffsets->asMutable<vector_size_t>();

// Process the rows: store unique values in the hash table.
ValueSet<T> uniqueSet;
ValueSetT uniqueSet;

rows.applyToSelected([&](vector_size_t row) {
auto size = arrayVector->sizeAt(row);
Expand All @@ -159,7 +169,7 @@ class ArrayDistinctFunction : public exec::VectorFunction {
}
} else {
bool unique;
if constexpr (std::is_same_v<ComplexType, T>) {
if constexpr (std::is_same_v<ValueSetT, ValueSet<ComplexType>>) {
unique = uniqueSet.insert(elements->base(), elements->index(i));
} else {
auto value = elements->valueAt<T>(i);
Expand Down Expand Up @@ -193,7 +203,7 @@ class ArrayDistinctFunction : public exec::VectorFunction {
};

template <>
VectorPtr ArrayDistinctFunction<UnknownType>::applyFlat(
VectorPtr ArrayDistinctFunction<UnknownType, false>::applyFlat(
const SelectivityVector& rows,
const VectorPtr& arg,
exec::EvalCtx& context) const {
Expand Down Expand Up @@ -259,11 +269,17 @@ void validateType(const std::vector<exec::VectorFunctionArg>& inputArgs) {
// Create function template based on type.
template <TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTyped(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
const std::vector<exec::VectorFunctionArg>& inputArgs,
const TypePtr& elementType) {
VELOX_CHECK_EQ(inputArgs.size(), 1);

using T = typename TypeTraits<kind>::NativeType;
return std::make_shared<ArrayDistinctFunction<T>>();

if (elementType->providesCustomComparison()) {
return std::make_shared<ArrayDistinctFunction<T, true>>();
} else {
return std::make_shared<ArrayDistinctFunction<T, false>>();
}
}

// Create function.
Expand All @@ -282,7 +298,7 @@ std::shared_ptr<exec::VectorFunction> create(
}

return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createTyped, elementType->kind(), inputArgs);
createTyped, elementType->kind(), inputArgs, elementType);
}

// Define function signature.
Expand Down
67 changes: 67 additions & 0 deletions velox/functions/prestosql/tests/ArrayDistinctTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <optional>
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox;
using namespace facebook::velox::test;
Expand Down Expand Up @@ -369,3 +370,69 @@ TEST_F(ArrayDistinctTest, unknownType) {
result = evaluate("array_distinct(c0)", makeRowVector({nullArrayVector}));
assertEqualVectors(expected, result);
}

// Checks array_distinct over arrays whose elements are of type
// TIMESTAMP WITH TIME ZONE. Every case is a single-row array; the expected
// output keeps the first occurrence of each distinct element, in input order.
TEST_F(ArrayDistinctTest, timestampWithTimezone) {
  using Values = std::vector<std::optional<int64_t>>;

  constexpr auto kMinMillis = std::numeric_limits<int64_t>::min();
  constexpr auto kMaxMillis = std::numeric_limits<int64_t>::max();

  // Wraps the given nullable values into a one-row array vector (single
  // offset 0) with TIMESTAMP WITH TIME ZONE elements.
  const auto makeTimestampArray = [this](const Values& values) {
    return makeArrayVector(
        {0}, makeNullableFlatVector(values, TIMESTAMP_WITH_TIME_ZONE()));
  };

  // Evaluates array_distinct on `given` and compares against `wanted`.
  const auto expectDistinct = [&](const Values& given, const Values& wanted) {
    const auto data = makeRowVector({makeTimestampArray(given)});
    const auto wantedVector = makeTimestampArray(wanted);

    assertEqualVectors(wantedVector, evaluate("array_distinct(c0)", data));
  };

  // Empty and single-element arrays come back unchanged.
  expectDistinct({}, {});
  expectDistinct({pack(0, 0)}, {pack(0, 0)});
  expectDistinct({pack(1, 0)}, {pack(1, 0)});
  expectDistinct({pack(kMinMillis, 0)}, {pack(kMinMillis, 0)});
  expectDistinct({pack(kMaxMillis, 0)}, {pack(kMaxMillis, 0)});
  expectDistinct({std::nullopt}, {std::nullopt});
  expectDistinct({pack(-1, 0)}, {pack(-1, 0)});

  // Positive first arguments to pack: elements whose first argument matches
  // an earlier element are dropped even though the second argument differs.
  expectDistinct(
      {pack(1, 3), pack(2, 2), pack(3, 1)},
      {pack(1, 3), pack(2, 2), pack(3, 1)});
  expectDistinct(
      {pack(1, 0), pack(2, 1), pack(1, 2)}, {pack(1, 0), pack(2, 1)});
  expectDistinct({pack(1, 0), pack(1, 1), pack(1, 2)}, {pack(1, 0)});

  // Same three shapes with negative first arguments.
  expectDistinct(
      {pack(-1, 0), pack(-2, 1), pack(-3, 2)},
      {pack(-1, 0), pack(-2, 1), pack(-3, 2)});
  expectDistinct(
      {pack(-1, 0), pack(-2, 1), pack(-1, 2)}, {pack(-1, 0), pack(-2, 1)});
  expectDistinct({pack(-1, 0), pack(-1, 1), pack(-1, 2)}, {pack(-1, 0)});

  // Repeated nulls collapse into a single null.
  expectDistinct({std::nullopt, std::nullopt, std::nullopt}, {std::nullopt});

  // Mixed signs with duplicates.
  expectDistinct(
      {pack(1, 0), pack(2, 1), pack(-2, 2), pack(1, 3)},
      {pack(1, 0), pack(2, 1), pack(-2, 2)});
  expectDistinct(
      {pack(1, 0),
       pack(1, 1),
       pack(-2, 2),
       pack(-2, 3),
       pack(-2, 4),
       pack(4, 5),
       pack(8, 6)},
      {pack(1, 0), pack(-2, 2), pack(4, 5), pack(8, 6)});

  // Nulls mixed with values: the first null is kept at its position.
  expectDistinct(
      {pack(3, 0), pack(8, 1), std::nullopt},
      {pack(3, 0), pack(8, 1), std::nullopt});
  expectDistinct(
      {pack(1, 0),
       pack(2, 1),
       pack(3, 2),
       std::nullopt,
       pack(4, 3),
       pack(1, 4),
       pack(2, 5),
       std::nullopt},
      {pack(1, 0), pack(2, 1), pack(3, 2), std::nullopt, pack(4, 3)});
}

0 comments on commit 86e87d7

Please sign in to comment.