Skip to content

Commit

Permalink
Add Decimal support to set_agg and set_union
Browse files Browse the repository at this point in the history
  • Loading branch information
rrando901 committed Dec 13, 2023
1 parent db64f6f commit 829978c
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 8 deletions.
26 changes: 18 additions & 8 deletions velox/functions/prestosql/aggregates/SetAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,9 @@ class SetUnionAggregate : public SetBaseAggregate<T> {

template <template <typename T> class Aggregate>
std::unique_ptr<exec::Aggregate> create(
TypeKind typeKind,
const TypePtr& inputType,
const TypePtr& resultType) {
switch (typeKind) {
switch (inputType->kind()) {
case TypeKind::BOOLEAN:
return std::make_unique<Aggregate<bool>>(resultType);
case TypeKind::TINYINT:
Expand All @@ -398,6 +398,11 @@ std::unique_ptr<exec::Aggregate> create(
return std::make_unique<Aggregate<int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<Aggregate<int64_t>>(resultType);
case TypeKind::HUGEINT:
VELOX_CHECK(
inputType->isLongDecimal(),
"Non-decimal use of HUGEINT is not supported");
return std::make_unique<Aggregate<int128_t>>(resultType);
case TypeKind::REAL:
return std::make_unique<Aggregate<float>>(resultType);
case TypeKind::DOUBLE:
Expand All @@ -415,7 +420,8 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::ROW:
return std::make_unique<Aggregate<ComplexType>>(resultType);
default:
VELOX_UNREACHABLE("Unexpected type {}", mapTypeKindToName(typeKind));
VELOX_UNREACHABLE(
"Unexpected type {}", mapTypeKindToName(inputType->kind()));
}
}

Expand Down Expand Up @@ -443,8 +449,9 @@ void registerSetAggAggregate(const std::string& prefix) {
VELOX_CHECK_EQ(argTypes.size(), 1);

const bool isRawInput = exec::isRawInput(step);
const TypeKind typeKind =
isRawInput ? argTypes[0]->kind() : argTypes[0]->childAt(0)->kind();
const TypePtr& inputType =
isRawInput ? argTypes[0] : argTypes[0]->childAt(0);
const TypeKind typeKind = inputType->kind();
const bool throwOnNestedNulls = isRawInput;

switch (typeKind) {
Expand All @@ -458,6 +465,11 @@ void registerSetAggAggregate(const std::string& prefix) {
return std::make_unique<SetAggAggregate<int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<SetAggAggregate<int64_t>>(resultType);
case TypeKind::HUGEINT:
VELOX_CHECK(
inputType->isLongDecimal(),
"Non-decimal use of HUGEINT is not supported");
return std::make_unique<SetAggAggregate<int128_t>>(resultType);
case TypeKind::REAL:
return std::make_unique<SetAggAggregate<float>>(resultType);
case TypeKind::DOUBLE:
Expand Down Expand Up @@ -503,9 +515,7 @@ void registerSetUnionAggregate(const std::string& prefix) {
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(argTypes.size(), 1);

const TypeKind typeKind = argTypes[0]->childAt(0)->kind();

return create<SetUnionAggregate>(typeKind, resultType);
return create<SetUnionAggregate>(argTypes[0]->childAt(0), resultType);
});
}

Expand Down
186 changes: 186 additions & 0 deletions velox/functions/prestosql/aggregates/tests/SetAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ namespace facebook::velox::aggregate::test {

namespace {

constexpr int64_t kLongMax = std::numeric_limits<int64_t>::max();
constexpr int64_t kLongMin = std::numeric_limits<int64_t>::min();
constexpr int128_t kHugeMax = std::numeric_limits<int128_t>::max();
constexpr int128_t kHugeMin = std::numeric_limits<int128_t>::min();

class SetAggTest : public AggregationTestBase {
protected:
void SetUp() override {
Expand Down Expand Up @@ -155,6 +160,187 @@ TEST_F(SetAggTest, groupBy) {
{expected});
}

TEST_F(SetAggTest, shortDecimal) {
// Test with short decimal
auto type = DECIMAL(6, 2);

auto data = makeRowVector({
makeFlatVector<int64_t>(
{kLongMin,
2000,
3000,
-4321,
kLongMax,
5000,
3000,
kLongMax,
-2000,
6000,
7000},
type),
});

auto expected = makeRowVector({
makeArrayVector<int64_t>(
{
{kLongMin, -4321, -2000, 2000, 3000, 5000, 6000, 7000, kLongMax},
},
type),
});

testAggregations({data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected});

// Test with some NULL inputs (short decimals)
data = makeRowVector({
makeNullableFlatVector<int64_t>(
{1000,
std::nullopt,
kLongMin,
4000,
std::nullopt,
4000,
std::nullopt,
-1000,
5000,
-9999,
kLongMax},
type),
});

expected = makeRowVector({
makeNullableArrayVector(
std::vector<std::vector<std::optional<int64_t>>>{
{kLongMin,
-9999,
-1000,
1000,
4000,
5000,
kLongMax,
std::nullopt}},
ARRAY(type)),
});

testAggregations({data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected});

// Test with all NULL inputs (short decimals)
data = makeRowVector({
makeNullableFlatVector<int64_t>(
{std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt},
type),
});

expected = makeRowVector({
makeNullableArrayVector(
std::vector<std::vector<std::optional<int64_t>>>{{std::nullopt}},
ARRAY(type)),
});

testAggregations({data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected});
}

TEST_F(SetAggTest, longDecimal) {
// Test with long decimal
auto type = DECIMAL(20, 2);

auto data = makeRowVector({
makeFlatVector<int128_t>(
{kHugeMin,
-2000,
3000,
4000,
5000,
kHugeMax,
-9630,
2000,
6000,
7000},
type),
});

auto expected = makeRowVector({
makeArrayVector<int128_t>(
{
{kHugeMin,
-9630,
-2000,
2000,
3000,
4000,
5000,
6000,
7000,
kHugeMax},
},
type),
});

testAggregations(
{data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected}, {}, false);

// Test with some NULL inputs (long decimals)
data = makeRowVector({
makeNullableFlatVector<int128_t>(
{1000,
std::nullopt,
3000,
4000,
std::nullopt,
kHugeMax,
-8424,
4000,
std::nullopt,
-1000,
5000,
kHugeMin,
2000},
type),
});

expected = makeRowVector({
makeNullableArrayVector(
std::vector<std::vector<std::optional<int128_t>>>{
{kHugeMin,
-8424,
-1000,
1000,
2000,
3000,
4000,
5000,
kHugeMax,
std::nullopt}},
ARRAY(type)),
});

testAggregations(
{data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected}, {}, false);

// Test with all NULL inputs (long decimals)
data = makeRowVector({
makeNullableFlatVector<int128_t>(
{std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt},
type),
});

expected = makeRowVector({
makeNullableArrayVector(
std::vector<std::vector<std::optional<int128_t>>>{{std::nullopt}},
ARRAY(type)),
});

testAggregations(
{data}, {}, {"set_agg(c0)"}, {"array_sort(a0)"}, {expected}, {}, false);
}

std::vector<std::optional<std::string>> generateStrings(
const std::vector<std::optional<std::string>>& choices,
vector_size_t size) {
Expand Down
Loading

0 comments on commit 829978c

Please sign in to comment.