Skip to content

Commit

Permalink
Add support for complex type keys in multimap_agg (#7814)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #7814

Reviewed By: xiaoxmeng

Differential Revision: D51727866

Pulled By: mbasmanova

fbshipit-source-id: bc07ed9c71420662560093f31e2dbd8ef5d134a4
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Dec 1, 2023
1 parent aa10bc5 commit aa60e3d
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 27 deletions.
159 changes: 132 additions & 27 deletions velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/exec/AddressableNonNullValueList.h"
#include "velox/exec/Aggregate.h"
#include "velox/exec/Strings.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
Expand Down Expand Up @@ -42,6 +43,17 @@ struct MultiMapAccumulator {
: keys{AlignedStlAllocator<std::pair<const K, ValueList>, 16>(
allocator)} {}

/// Constructs an accumulator whose 'keys' container uses caller-supplied
/// hashing and equality functors.
///
/// @param hash Hash functor handed to the 'keys' container.
/// @param equalTo Equality functor handed to the 'keys' container.
/// @param allocator Wrapped in an AlignedStlAllocator and handed to the
/// 'keys' container as its allocator.
///
/// NOTE(review): the template argument 16 of AlignedStlAllocator is presumably
/// the alignment in bytes - confirm against AlignedStlAllocator's definition.
MultiMapAccumulator(
Hash hash,
EqualTo equalTo,
HashStringAllocator* allocator)
: keys{
0, // initialCapacity
hash,
equalTo,
AlignedStlAllocator<std::pair<const K, ValueList>, 16>(
allocator)} {}

/// Returns the number of entries currently stored in 'keys'.
size_t size() const {
return keys.size();
}
Expand Down Expand Up @@ -124,13 +136,116 @@ struct MultiMapAccumulator {
}
};

/// Multimap accumulator for keys that are serialized into an
/// AddressableNonNullValueList. The map itself is keyed by the position of
/// each serialized key; hashing and equality are delegated to
/// AddressableNonNullValueList::Hash and AddressableNonNullValueList::EqualTo.
struct ComplexTypeMultiMapAccumulator {
  /// Maps the position of a serialized key to the list of values collected
  /// for that key.
  MultiMapAccumulator<
      HashStringAllocator::Position,
      AddressableNonNullValueList::Hash,
      AddressableNonNullValueList::EqualTo>
      base;

  /// Stores unique non-null keys.
  AddressableNonNullValueList serializedKeys;

  /// @param type Passed to AddressableNonNullValueList::EqualTo, which is used
  /// to compare keys.
  /// @param allocator Forwarded to the underlying MultiMapAccumulator.
  ComplexTypeMultiMapAccumulator(
      const TypePtr& type,
      HashStringAllocator* allocator)
      : base{
            AddressableNonNullValueList::Hash{},
            AddressableNonNullValueList::EqualTo{type},
            allocator} {}

  /// Returns the number of entries in the underlying accumulator.
  size_t size() const {
    return base.size();
  }

  /// Returns base.numValues().
  size_t numValues() const {
    return base.numValues();
  }

  /// Adds a single key-value pair. Both the key and the value are read from
  /// row 'index' of their respective decoded vectors.
  void insert(
      const DecodedVector& decodedKeys,
      const DecodedVector& decodedValues,
      vector_size_t index,
      HashStringAllocator& allocator) {
    // A single pair is a key followed by a run of exactly one value.
    insertMultiple(decodedKeys, index, decodedValues, index, 1, allocator);
  }

  /// Adds the key found at 'keyIndex' together with 'numValues' consecutive
  /// values starting at 'valueIndex'.
  void insertMultiple(
      const DecodedVector& decodedKeys,
      vector_size_t keyIndex,
      const DecodedVector& decodedValues,
      vector_size_t valueIndex,
      vector_size_t numValues,
      HashStringAllocator& allocator) {
    // The key is serialized first; insertKey() undoes the append if the key
    // turns out to be present already.
    const auto keyPosition =
        serializedKeys.append(decodedKeys, keyIndex, &allocator);

    auto& valueList = insertKey(keyPosition);
    for (auto offset = 0; offset < numValues; ++offset) {
      valueList.appendValue(decodedValues, valueIndex + offset, &allocator);
    }
  }

  /// Registers the key stored at 'position' and returns the value list
  /// associated with it. If an equal key was registered before, calls
  /// serializedKeys.removeLast(position) and returns the existing list.
  ValueList& insertKey(HashStringAllocator::Position position) {
    auto [it, inserted] = base.keys.insert({position, ValueList()});
    if (!inserted) {
      serializedKeys.removeLast(position);
    }

    return it->second;
  }

  /// Writes every accumulated entry out. Each key is read into 'mapKeys' at
  /// 'keyOffset'; its values are read into the elements of 'mapValueArrays'
  /// starting at 'valueOffset'. Both offsets are advanced past what was
  /// written.
  void extract(
      VectorPtr& mapKeys,
      ArrayVector& mapValueArrays,
      vector_size_t& keyOffset,
      vector_size_t& valueOffset) {
    auto& elements = mapValueArrays.elements();

    for (auto& [keyPosition, valueList] : base.keys) {
      AddressableNonNullValueList::read(keyPosition, *mapKeys, keyOffset);

      const auto numValues = valueList.size();
      mapValueArrays.setOffsetAndSize(keyOffset, valueOffset, numValues);

      aggregate::ValueListReader reader(valueList);
      for (auto remaining = numValues; remaining > 0; --remaining) {
        reader.next(*elements, valueOffset++);
      }

      ++keyOffset;
    }
  }

  /// Calls free() on the underlying accumulator and then on the serialized
  /// keys, passing 'allocator' to both.
  void free(HashStringAllocator& allocator) {
    base.free(allocator);
    serializedKeys.free(allocator);
  }
};

/// Selects the accumulator type for a given key type T. By default this is
/// MultiMapAccumulator<T>.
template <typename T>
struct MultiMapAccumulatorTypeTraits {
using AccumulatorType = MultiMapAccumulator<T>;
};

/// Specialization for ComplexType keys: selects
/// ComplexTypeMultiMapAccumulator instead of MultiMapAccumulator<ComplexType>.
template <>
struct MultiMapAccumulatorTypeTraits<ComplexType> {
using AccumulatorType = ComplexTypeMultiMapAccumulator;
};

template <typename K>
class MultiMapAggAggregate : public exec::Aggregate {
public:
explicit MultiMapAggAggregate(TypePtr resultType)
: exec::Aggregate(std::move(resultType)) {}

using AccumulatorType = MultiMapAccumulator<K>;
using AccumulatorType =
typename MultiMapAccumulatorTypeTraits<K>::AccumulatorType;

bool isFixedSize() const override {
return false;
Expand Down Expand Up @@ -365,32 +480,15 @@ class MultiMapAggAggregate : public exec::Aggregate {

exec::AggregateRegistrationResult registerMultiMapAggAggregate(
const std::string& prefix) {
static const std::vector<std::string> kSupportedKeyTypes = {
"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"real",
"double",
"timestamp",
"date",
"varbinary",
"varchar",
"unknown",
};

std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
for (const auto& keyType : kSupportedKeyTypes) {
signatures.emplace_back(
exec::AggregateFunctionSignatureBuilder()
.typeVariable("V")
.returnType(fmt::format("map({},array(V))", keyType))
.intermediateType(fmt::format("map({},array(V))", keyType))
.argumentType(keyType)
.argumentType("V")
.build());
}
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.typeVariable("K")
.typeVariable("V")
.returnType("map(K,array(V))")
.intermediateType("map(K,array(V))")
.argumentType("K")
.argumentType("V")
.build()};

auto name = prefix + kMultiMapAgg;
return exec::registerAggregateFunction(
Expand Down Expand Up @@ -426,6 +524,13 @@ exec::AggregateRegistrationResult registerMultiMapAggAggregate(
case TypeKind::VARCHAR:
return std::make_unique<MultiMapAggAggregate<StringView>>(
resultType);
case TypeKind::ARRAY:
[[fallthrough]];
case TypeKind::MAP:
[[fallthrough]];
case TypeKind::ROW:
return std::make_unique<MultiMapAggAggregate<ComplexType>>(
resultType);
case TypeKind::UNKNOWN:
return std::make_unique<MultiMapAggAggregate<int32_t>>(resultType);
default:
Expand Down
86 changes: 86 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,91 @@ TEST_F(MultiMapAggTest, stringKeyGroupBy) {
{expected});
}

// Global (no grouping keys) multimap_agg with array-typed keys. The input
// contains repeated keys, empty arrays and one null key; the row with the
// null key has no counterpart in the expected map.
TEST_F(MultiMapAggTest, arrayKeyGlobal) {
  auto inputKeys = makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[]",
      "null",
      "[1, 2, 3, 4]",
      "[1, 2]",
      "[1, 2, 3]",
      "[1, 2]",
      "[]",
  });
  auto inputValues = makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6, 7, 8});
  auto data = makeRowVector({inputKeys, inputValues});

  // One map holding all four distinct keys, each paired with the values of
  // the rows that carried that key.
  auto expectedKeys = makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[]",
      "[1, 2, 3, 4]",
      "[1, 2]",
  });
  auto expectedValues = makeArrayVectorFromJson<int64_t>({
      "[1, 6]",
      "[2, 8]",
      "[4]",
      "[5, 7]",
  });
  auto expected = makeRowVector({
      makeMapVector({0}, expectedKeys, expectedValues),
  });

  testAggregations(
      {data},
      {},
      {"multimap_agg(c0, c1)"},
      // Sort the result arrays to ensure deterministic results.
      {"transform_values(a0, (k, v) -> array_sort(v))"},
      {expected});
}

// Grouped multimap_agg with array-typed keys. Rows alternate between groups
// 1 and 2; the row with the null key (group 1) has no counterpart in the
// expected map for that group.
TEST_F(MultiMapAggTest, arrayKeyGroupBy) {
  auto groupingKeys = makeFlatVector<int16_t>({1, 2, 1, 2, 1, 2, 1, 2});
  auto inputKeys = makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[]",
      "null",
      "[1, 2, 3, 4]",
      "[1, 2]",
      "[1, 2, 3]",
      "[1, 2]",
      "[]",
  });
  auto inputValues = makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6, 7, 8});
  auto data = makeRowVector({groupingKeys, inputKeys, inputValues});

  // Entries [0, 2) form the map for group 1, entries [2, 5) the map for
  // group 2.
  auto expectedGroups = makeFlatVector<int16_t>({1, 2});
  auto expectedKeys = makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[1, 2]",
      "[]",
      "[1, 2, 3]",
      "[1, 2, 3, 4]",
  });
  auto expectedValues = makeArrayVectorFromJson<int64_t>({
      "[1]",
      "[5, 7]",
      "[2, 8]",
      "[6]",
      "[4]",
  });
  auto expected = makeRowVector({
      expectedGroups,
      makeMapVector({0, 2}, expectedKeys, expectedValues),
  });

  testAggregations(
      {data},
      {"c0"},
      {"multimap_agg(c1, c2)"},
      // Sort the result arrays to ensure deterministic results.
      {"c0", "transform_values(a0, (k, v) -> array_sort(v))"},
      {expected});
}

} // namespace
} // namespace facebook::velox::aggregate::prestosql

0 comments on commit aa60e3d

Please sign in to comment.