Skip to content

Commit

Permalink
Adding decimal support for min() and max() functions (facebookincubat…
Browse files Browse the repository at this point in the history
…or#9005)

Summary:
Delivers facebookincubator#9004

This PR is for adding decimal support for min() and max() functions.

presto-cli output:
```
presto:tpch> SELECT MIN(Col,2) FROM (VALUES cast(0.82 as decimal(5,4)), cast(2.333 as decimal(5,4)), cast(3.132 as decimal(5,4)), cast(4.344 as decimal(5,4))) AS X(Col);
      _col0
------------------
 [0.8200, 2.3330]
(1 row)
presto:tpch> SELECT MIN(Col,3) FROM (VALUES cast(0.82 as decimal(5,4)), cast(2.333 as decimal(5,4)), cast(3.132 as decimal(5,4)), cast(4.344 as decimal(5,4))) AS X(Col);
          _col0
--------------------------
 [0.8200, 2.3330, 3.1320]
(1 row)
presto:tpch> SELECT MAX(Col,2) FROM (VALUES cast(0.82 as decimal(5,4)), cast(2.333 as decimal(5,4)), cast(3.132 as decimal(5,4)), cast(4.344 as decimal(5,4))) AS X(Col);
      _col0
------------------
 [4.3440, 3.1320]
(1 row)
presto:tpch> SELECT MAX(Col,3) FROM (VALUES cast(0.82 as decimal(5,4)), cast(2.333 as decimal(5,4)), cast(3.132 as decimal(5,4)), cast(4.344 as decimal(5,4))) AS X(Col);
          _col0
--------------------------
 [4.3440, 3.1320, 2.3330]
(1 row)

```

Pull Request resolved: facebookincubator#9005

Reviewed By: pedroerp

Differential Revision: D56487562

Pulled By: Yuhta

fbshipit-source-id: 283eb189a91840835784e565cd33fa713167d17d
  • Loading branch information
minhancao authored and Joe-Abraham committed Jun 7, 2024
1 parent 15a338c commit 094da60
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 3 deletions.
25 changes: 22 additions & 3 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,14 @@ std::pair<vector_size_t*, vector_size_t*> rawOffsetAndSizes(
template <typename T, typename Compare>
struct MinMaxNAccumulator {
int64_t n{0};
std::vector<T, StlAllocator<T>> heapValues;
using Allocator = std::conditional_t<
std::is_same_v<int128_t, T>,
AlignedStlAllocator<T, sizeof(int128_t)>,
StlAllocator<T>>;
std::vector<T, Allocator> heapValues;

explicit MinMaxNAccumulator(HashStringAllocator* allocator)
: heapValues{StlAllocator<T>(allocator)} {}
: heapValues{Allocator(allocator)} {}

int64_t getN() const {
return n;
Expand Down Expand Up @@ -916,6 +920,18 @@ exec::AggregateRegistrationResult registerMinMax(
.build());
}

// decimal(p,s), bigint -> row(array(decimal(p,s)), bigint) ->
// array(decimal(p,s))
signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.argumentType("bigint")
.intermediateType("row(bigint, array(DECIMAL(a_precision, a_scale)))")
.returnType("array(DECIMAL(a_precision, a_scale))")
.build());

return exec::registerAggregateFunction(
name,
std::move(signatures),
Expand Down Expand Up @@ -952,7 +968,10 @@ exec::AggregateRegistrationResult registerMinMax(
case TypeKind::TIMESTAMP:
return std::make_unique<TNumericN<Timestamp>>(resultType);
case TypeKind::HUGEINT:
return std::make_unique<TNumericN<int128_t>>(resultType);
if (inputType->isLongDecimal()) {
return std::make_unique<TNumericN<int128_t>>(resultType);
}
VELOX_UNREACHABLE();
default:
VELOX_CHECK(
false,
Expand Down
231 changes: 231 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,102 @@ class MinMaxNTest : public functions::aggregate::test::AggregationTestBase {
"second argument of max/min must be less than or equal to 10000");
}

template <typename T>
void testNumericGlobalDecimal() {
TypePtr type;
if (std::is_same<T, int64_t>::value) {
type = DECIMAL(6, 2);
} else {
type = DECIMAL(20, 2);
}
auto data = makeRowVector({
makeFlatVector<T>(
{100000,
131011,
223454,
111911,
111300,
800000,
104000,
712452,
161213,
135243},
type),
});
auto expected = makeRowVector({
makeArrayVector<T>(
{
{100000, 104000},
},
type),
makeArrayVector<T>(
{
{100000, 104000, 111300, 111911, 131011},
},
type),
makeArrayVector<T>(
{
{800000, 712452, 223454},
},
type),
makeArrayVector<T>(
{
{800000, 712452, 223454, 161213, 135243, 131011, 111911},
},
type),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});

// Add some nulls. Expect these to be ignored.
data = makeRowVector({
makeNullableFlatVector<T>(
{100000,
std::nullopt,
131011,
223454,
111911,
std::nullopt,
111300,
800000,
104000,
712452,
161213,
135243,
std::nullopt},
type),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});

// Test all null input.
data = makeRowVector({
makeNullableFlatVector<T>(
{std::nullopt, std::nullopt, std::nullopt, std::nullopt}, type),
});

expected = makeRowVector({
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});
}

template <typename T>
void testNumericGroupBy() {
auto data = makeRowVector({
Expand Down Expand Up @@ -717,6 +813,131 @@ class MinMaxNTest : public functions::aggregate::test::AggregationTestBase {
{"min(c1, c2)", "min(c1, c4)", "max(c1, c3)", "max(c1, c4)"},
{expected});
}

template <typename T>
void testNumericGroupByDecimal() {
TypePtr type;
if (std::is_same<T, int64_t>::value) {
type = DECIMAL(6, 2);
} else {
type = DECIMAL(20, 2);
}

auto data = makeRowVector({
makeFlatVector<int16_t>({1, 2, 1, 1, 2, 2, 1, 2}),
makeFlatVector<T>(
{100000, 131011, 223454, 111911, 111300, 104000, 161213, 135243},
type),
});

auto expected = makeRowVector({
makeFlatVector<int16_t>({1, 2}),
makeArrayVector<T>(
{
{100000, 111911},
{104000, 111300},
},
type),
makeArrayVector<T>(
{
{100000, 111911, 161213, 223454},
{104000, 111300, 131011, 135243},
},
type),
makeArrayVector<T>(
{
{223454, 161213, 111911},
{135243, 131011, 111300},
},
type),
makeArrayVector<T>(
{
{223454, 161213, 111911, 100000},
{135243, 131011, 111300, 104000},
},
type),
});

testAggregations(
{data},
{"c0"},
{"min(c1, 2)", "min(c1, 5)", "max(c1, 3)", "max(c1, 7)"},
{expected});

// Add some nulls. Expect these to be ignored.
data = makeRowVector({
makeFlatVector<int16_t>({1, 2, 1, 1, 1, 2, 2, 2, 1, 2}),
makeNullableFlatVector<T>(
{100000,
131011,
std::nullopt,
223454,
111911,
111300,
std::nullopt,
104000,
161213,
135243},
type),
});

testAggregations(
{data},
{"c0"},
{"min(c1, 2)", "min(c1, 5)", "max(c1, 3)", "max(c1, 7)"},
{expected});

// Test all null input.
data = makeRowVector({
makeFlatVector<int16_t>({1, 2, 1, 1, 1, 2, 2, 2, 1, 2}),
makeNullableFlatVector<T>(
{std::nullopt,
131011,
std::nullopt,
std::nullopt,
std::nullopt,
111300,
std::nullopt,
104000,
std::nullopt,
135243},
type),
});

expected = makeRowVector({
makeFlatVector<int16_t>({1, 2}),
makeNullableArrayVector<T>(
{
std::nullopt,
{{{104000, 111300}}},
},
ARRAY(type)),
makeNullableArrayVector<T>(
{
std::nullopt,
{{{104000, 111300, 131011, 135243}}},
},
ARRAY(type)),
makeNullableArrayVector<T>(
{
std::nullopt,
{{{135243, 131011, 111300}}},
},
ARRAY(type)),
makeNullableArrayVector<T>(
{
std::nullopt,
{{{135243, 131011, 111300, 104000}}},
},
ARRAY(type)),
});

testAggregations(
{data},
{"c0"},
{"min(c1, 2)", "min(c1, 5)", "max(c1, 3)", "max(c1, 7)"},
{expected});
}
};

TEST_F(MinMaxNTest, tinyint) {
Expand Down Expand Up @@ -749,6 +970,16 @@ TEST_F(MinMaxNTest, double) {
testNumericGroupBy<double>();
}

TEST_F(MinMaxNTest, shortdecimal) {
testNumericGlobalDecimal<int64_t>();
testNumericGroupByDecimal<int64_t>();
}

TEST_F(MinMaxNTest, longdecimal) {
testNumericGlobalDecimal<int128_t>();
testNumericGroupByDecimal<int128_t>();
}

TEST_F(MinMaxNTest, incrementalWindow) {
// SELECT
// c0, c1, c2, c3,
Expand Down

0 comments on commit 094da60

Please sign in to comment.