diff --git a/velox/functions/prestosql/aggregates/PrestoHasher.cpp b/velox/functions/prestosql/aggregates/PrestoHasher.cpp index a2e9d93ac9b32..b3eac733521ac 100644 --- a/velox/functions/prestosql/aggregates/PrestoHasher.cpp +++ b/velox/functions/prestosql/aggregates/PrestoHasher.cpp @@ -231,14 +231,17 @@ void PrestoHasher::hash( BufferPtr& hashes) { auto baseRow = vector_->base()->as(); auto indices = vector_->indices(); - SelectivityVector elementRows; - if (vector_->isIdentityMapping()) { + SelectivityVector elementRows; + if (vector_->isIdentityMapping() && !vector_->mayHaveNulls()) { elementRows = rows; } else { elementRows = SelectivityVector(baseRow->size(), false); - rows.applyToSelected( - [&](auto row) { elementRows.setValid(indices[row], true); }); + rows.applyToSelected([&](auto row) { + if (!vector_->isNullAt(row)) { + elementRows.setValid(indices[row], true); + } + }); elementRows.updateBounds(); } @@ -246,14 +249,14 @@ void PrestoHasher::hash( AlignedBuffer::allocate(elementRows.end(), baseRow->pool()); auto rawHashes = hashes->asMutable(); - auto rowChildHashes = childHashes->as(); if (isTimestampWithTimeZoneType(vector_->base()->type())) { // Hash only timestamp value. 
children_[0]->hash(baseRow->childAt(0), elementRows, childHashes); + auto rawChildHashes = childHashes->as(); rows.applyToSelected([&](auto row) { if (!baseRow->isNullAt(indices[row])) { - rawHashes[row] = rowChildHashes[indices[row]]; + rawHashes[row] = rawChildHashes[indices[row]]; } else { rawHashes[row] = 0; } @@ -261,15 +264,30 @@ void PrestoHasher::hash( return; } + BufferPtr combinedChildHashes = + AlignedBuffer::allocate(elementRows.end(), baseRow->pool()); + auto* rawCombinedChildHashes = combinedChildHashes->asMutable(); + std::fill_n(rawCombinedChildHashes, elementRows.end(), 1); + std::fill_n(rawHashes, rows.end(), 1); for (int i = 0; i < baseRow->childrenSize(); i++) { children_[i]->hash(baseRow->childAt(i), elementRows, childHashes); - rows.applyToSelected([&](auto row) { - rawHashes[row] = safeHash(rawHashes[row], rowChildHashes[indices[row]]); + auto rawChildHashes = childHashes->as(); + elementRows.applyToSelected([&](auto row) { + rawCombinedChildHashes[row] = + safeHash(rawCombinedChildHashes[row], rawChildHashes[row]); }); } + + rows.applyToSelected([&](auto row) { + if (!vector_->isNullAt(row)) { + rawHashes[row] = rawCombinedChildHashes[indices[row]]; + } else { + rawHashes[row] = 0; + } + }); } void PrestoHasher::hash( diff --git a/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp b/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp index 2d09330b13064..75f97eaa8b04c 100644 --- a/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp @@ -205,16 +205,45 @@ TEST_F(ChecksumAggregateTest, maps) { } TEST_F(ChecksumAggregateTest, rows) { - auto row = makeRowVector( - {makeFlatVector({1, 3}), makeFlatVector({2, 4})}); + auto row = makeRowVector({ + makeFlatVector({1, 3}), + makeFlatVector({2, 4}), + }); assertChecksum(row, "jMIvLQ5YEVg="); - row = makeRowVector( - {makeNullableFlatVector({1, std::nullopt}), - 
makeNullableFlatVector({std::nullopt, 4})}); + row->setNull(0, true); + assertChecksum(row, "nbYF0I9UTeU="); + + row->setNull(1, true); + assertChecksum(row, "DpXXC2Pzbjw="); + + row = makeRowVector({ + makeNullableFlatVector({1, std::nullopt}), + makeNullableFlatVector({std::nullopt, 4}), + }); assertChecksum(row, "6jtxEIUj7Hg="); + + row = makeRowVector({ + makeRowVector({ + makeNullableFlatVector({"Hello", "world!"}), + makeNullableFlatVector({true, false}), + }), + makeNullableFlatVector({17, 4}), + }); + + assertChecksum(row, "21pwcVg31Kc="); + + row = makeRowVector({ + makeRowVector({ + makeNullableFlatVector({"Hello", std::nullopt}), + makeNullableFlatVector({std::nullopt, false}), + }), + makeNullableFlatVector({std::nullopt, 4}), + }); + + assertChecksum(row, "Aw9tzUPOiUc="); } TEST_F(ChecksumAggregateTest, globalAggregationNoData) { diff --git a/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp b/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp index ef1e58a442c0b..8761bb16f03bb 100644 --- a/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp @@ -310,16 +310,25 @@ TEST_F(PrestoHasherTest, maps) { } TEST_F(PrestoHasherTest, rows) { - auto row = makeRowVector( - {makeFlatVector({1, 3}), makeFlatVector({2, 4})}); + auto row = makeRowVector({ + makeFlatVector({1, 3}), + makeFlatVector({2, 4}), + }); assertHash(row, {4329740752828761434, 655643799837772474}); - row = makeRowVector( - {makeNullableFlatVector({1, std::nullopt}), - makeNullableFlatVector({std::nullopt, 4})}); + row = makeRowVector({ + makeNullableFlatVector({1, std::nullopt}), + makeNullableFlatVector({std::nullopt, 4}), + }); assertHash(row, {7113531408683827503, -1169223928725763049}); + + row->setNull(0, true); + assertHash(row, {0, -1169223928725763049}); + + row->setNull(1, true); + assertHash(row, {0, 0}); } TEST_F(PrestoHasherTest, wrongVectorType) {