diff --git a/velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp b/velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp index 17f62ab63085d..68eac78ac8b2f 100644 --- a/velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp +++ b/velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp @@ -18,6 +18,7 @@ #include "velox/exec/Strings.h" #include "velox/functions/lib/aggregates/ValueList.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/type/FloatingPointUtil.h" #include "velox/vector/FlatVector.h" namespace facebook::velox::aggregate::prestosql { @@ -232,6 +233,23 @@ struct MultiMapAccumulatorTypeTraits { using AccumulatorType = MultiMapAccumulator; }; +// Ensure Accumulator treats NaNs as equal. +template <> +struct MultiMapAccumulatorTypeTraits { + using AccumulatorType = MultiMapAccumulator< + float, + util::floating_point::NaNAwareHash, + util::floating_point::NaNAwareEquals>; +}; + +template <> +struct MultiMapAccumulatorTypeTraits { + using AccumulatorType = MultiMapAccumulator< + double, + util::floating_point::NaNAwareHash, + util::floating_point::NaNAwareEquals>; +}; + template <> struct MultiMapAccumulatorTypeTraits { using AccumulatorType = ComplexTypeMultiMapAccumulator; diff --git a/velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp b/velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp index 38cc67afcedf7..3c63283331809 100644 --- a/velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" @@ -266,5 +267,40 @@ TEST_F(MultiMapAggTest, arrayKeyGroupBy) { {expected}); } +TEST_F(MultiMapAggTest, doubleKeyGlobal) { + // Verify that all NaN representations used as a map key are treated as equal + static const double KNan1 = std::nan("1"); + static const double KNan2 = std::nan("2"); + auto data = makeRowVector({ + makeFlatVector( + {KNan1, KNan2, 1.1, 0.2, 23.0, 2.0, 23.0, 2.0, 1.1, 0.2, 23.0, 2.0}), + makeNullableFlatVector( + {-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + }); + + auto expected = makeRowVector({ + makeMapVector( + { + 0, + }, + makeFlatVector({KNan1, 0.2, 1.1, 2.0, 23.0}), + makeArrayVector({ + {-2, -1}, + {1, 7}, + {0, 6}, + {3, 5, 9}, + {2, 4, 8}, + })), + }); + + testAggregations( + {data}, + {}, + {"multimap_agg(c0, c1)"}, + // Sort the result arrays to ensure deterministic results. + {"transform_values(a0, (k, v) -> array_sort(v))"}, + {expected}); +} + } // namespace } // namespace facebook::velox::aggregate::prestosql