Skip to content

Commit

Permalink
Enable testReadFromFiles in testAggregations (#6852)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #6852

Reviewed By: Yuhta

Differential Revision: D48737844
  • Loading branch information
kagamiori authored and facebook-github-bot committed Oct 5, 2023
1 parent f747c38 commit 58746be
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 146 deletions.
5 changes: 5 additions & 0 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ TEST_F(SimpleArrayAggAggregationTest, numbers) {
{inputVectors},
{},
{"simple_array_agg(c2)", "simple_array_agg(c3)"},
{"array_sort(a0)", "array_sort(a1)"},
{expected});

expected = makeRowVector(
Expand All @@ -196,6 +197,7 @@ TEST_F(SimpleArrayAggAggregationTest, numbers) {
{inputVectors},
{"c0"},
{"simple_array_agg(c2)", "simple_array_agg(c3)"},
{"c0", "array_sort(a0)", "array_sort(a1)"},
{expected});

expected = makeRowVector(
Expand All @@ -209,6 +211,7 @@ TEST_F(SimpleArrayAggAggregationTest, numbers) {
{inputVectors},
{"c1"},
{"simple_array_agg(c2)", "simple_array_agg(c3)"},
{"c1", "array_sort(a0)", "array_sort(a1)"},
{expected});

inputVectors = makeRowVector({makeNullableFlatVector<int64_t>(
Expand Down Expand Up @@ -248,6 +251,7 @@ TEST_F(SimpleArrayAggAggregationTest, nestedArray) {
{inputVectors},
{"c0"},
{"simple_array_agg(c1)", "simple_array_agg(c2)"},
{"c0", "array_sort(a0)", "array_sort(a1)"},
{expected});

expected = makeRowVector(
Expand All @@ -269,6 +273,7 @@ TEST_F(SimpleArrayAggAggregationTest, nestedArray) {
{inputVectors},
{},
{"simple_array_agg(c1)", "simple_array_agg(c2)"},
{"array_sort(a0)", "array_sort(a1)"},
{expected});
}

Expand Down
79 changes: 65 additions & 14 deletions velox/functions/lib/aggregates/tests/AggregationTestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,28 @@ void AggregationTestBase::testAggregations(
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config) {
const std::unordered_map<std::string, std::string>& config,
bool testWithTableScan) {
SCOPED_TRACE(duckDbSql);
testAggregations(data, groupingKeys, aggregates, {}, duckDbSql, config);
testAggregations(
data, groupingKeys, aggregates, {}, duckDbSql, config, testWithTableScan);
}

void AggregationTestBase::testAggregations(
const std::vector<RowVectorPtr>& data,
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& expectedResult,
const std::unordered_map<std::string, std::string>& config) {
testAggregations(data, groupingKeys, aggregates, {}, expectedResult, config);
const std::unordered_map<std::string, std::string>& config,
bool testWithTableScan) {
testAggregations(
data,
groupingKeys,
aggregates,
{},
expectedResult,
config,
testWithTableScan);
}

void AggregationTestBase::testAggregations(
Expand All @@ -108,15 +118,17 @@ void AggregationTestBase::testAggregations(
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config) {
const std::unordered_map<std::string, std::string>& config,
bool testWithTableScan) {
SCOPED_TRACE(duckDbSql);
testAggregations(
[&](PlanBuilder& builder) { builder.values(data); },
groupingKeys,
aggregates,
postAggregationProjections,
[&](auto& builder) { return builder.assertResults(duckDbSql); },
config);
config,
testWithTableScan);
}

namespace {
Expand Down Expand Up @@ -540,7 +552,8 @@ void AggregationTestBase::testReadFromFiles(
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
std::function<std::shared_ptr<exec::Task>(exec::test::AssertQueryBuilder&)>
assertResults) {
assertResults,
const std::unordered_map<std::string, std::string>& config) {
PlanBuilder builder(pool());
makeSource(builder);
auto input = AssertQueryBuilder(builder.planNode()).copyResults(pool());
Expand All @@ -565,12 +578,17 @@ void AggregationTestBase::testReadFromFiles(
// so it would be the same as the original test.
{
ScopedChange<bool> disableTestStreaming(&testStreaming_, false);
testAggregations(
testAggregationsImpl(
[&](auto& builder) { builder.tableScan(asRowType(input->type())); },
groupingKeys,
aggregates,
postAggregationProjections,
[&](auto& builder) { return assertResults(builder.splits(splits)); });
[&](auto& builder) { return assertResults(builder.splits(splits)); },
config);
}

for (const auto& file : files) {
remove(file->path.c_str());
}
}

Expand All @@ -580,29 +598,33 @@ void AggregationTestBase::testAggregations(
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
const std::vector<RowVectorPtr>& expectedResult,
const std::unordered_map<std::string, std::string>& config) {
const std::unordered_map<std::string, std::string>& config,
bool testWithTableScan) {
testAggregations(
[&](PlanBuilder& builder) { builder.values(data); },
groupingKeys,
aggregates,
postAggregationProjections,
[&](auto& builder) { return builder.assertResults(expectedResult); },
config);
config,
testWithTableScan);
}

void AggregationTestBase::testAggregations(
std::function<void(PlanBuilder&)> makeSource,
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config) {
const std::unordered_map<std::string, std::string>& config,
bool testWithTableScan) {
testAggregations(
makeSource,
groupingKeys,
aggregates,
{},
[&](auto& builder) { return builder.assertResults(duckDbSql); },
config);
config,
testWithTableScan);
}

RowVectorPtr AggregationTestBase::validateStreamingInTestAggregations(
Expand Down Expand Up @@ -667,7 +689,7 @@ RowVectorPtr AggregationTestBase::validateStreamingInTestAggregations(
return expected;
}

void AggregationTestBase::testAggregations(
void AggregationTestBase::testAggregationsImpl(
std::function<void(PlanBuilder&)> makeSource,
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
Expand Down Expand Up @@ -886,6 +908,35 @@ void AggregationTestBase::testAggregations(
}
}

/// Verifies the given aggregation against the plan source built by
/// 'makeSource', then optionally repeats the verification through
/// testReadFromFiles().
///
/// @param makeSource Callback that adds the source node(s) to a PlanBuilder.
/// @param groupingKeys Grouping key names forwarded unchanged.
/// @param aggregates Aggregate expressions forwarded unchanged.
/// @param postAggregationProjections Projections forwarded unchanged.
/// @param assertResults Callback that runs the query and checks its results.
/// @param config Query configuration forwarded to both verification passes.
/// @param testWithTableScan When false, only the first pass is executed.
void AggregationTestBase::testAggregations(
    std::function<void(PlanBuilder&)> makeSource,
    const std::vector<std::string>& groupingKeys,
    const std::vector<std::string>& aggregates,
    const std::vector<std::string>& postAggregationProjections,
    std::function<std::shared_ptr<exec::Task>(AssertQueryBuilder&)>
        assertResults,
    const std::unordered_map<std::string, std::string>& config,
    bool testWithTableScan) {
  // First pass: always check against the caller-provided source.
  testAggregationsImpl(
      makeSource,
      groupingKeys,
      aggregates,
      postAggregationProjections,
      assertResults,
      config);

  if (!testWithTableScan) {
    return;
  }

  // Second pass: same arguments, routed through testReadFromFiles(). The trace
  // stays active until the end of this function, so any failure reported by
  // the second pass is labelled accordingly.
  SCOPED_TRACE("Test reading input from table scan");
  testReadFromFiles(
      makeSource,
      groupingKeys,
      aggregates,
      postAggregationProjections,
      assertResults,
      config);
}

namespace {
std::pair<TypePtr, TypePtr> getResultTypes(
const std::string& name,
Expand Down
36 changes: 27 additions & 9 deletions velox/functions/lib/aggregates/tests/AggregationTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Same as above, but allows specifying a set of projections to apply after
/// the aggregation.
Expand All @@ -79,7 +80,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& postAggregationProjections,
std::function<std::shared_ptr<exec::Task>(
exec::test::AssertQueryBuilder& builder)> assertResults,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Convenience version that allows specifying input data instead of a
/// function to build a Values plan node.
Expand All @@ -88,7 +90,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Convenience version that allows specifying input data instead of a
/// function to build a Values plan node.
Expand All @@ -98,7 +101,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
const std::string& duckDbSql,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Convenience version that allows to specify input data instead of a
/// function to build Values plan node, and the expected result instead of a
Expand All @@ -108,7 +112,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& expectedResult,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Convenience version that allows to specify input data instead of a
/// function to build Values plan node, and the expected result instead of a
Expand All @@ -119,7 +124,8 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
const std::vector<RowVectorPtr>& expectedResult,
const std::unordered_map<std::string, std::string>& config = {});
const std::unordered_map<std::string, std::string>& config = {},
bool testWithTableScan = true);

/// Ensure the function is working in streaming use case. Create a first
/// aggregation function, add the rawInput1, then extract the accumulator,
Expand Down Expand Up @@ -189,21 +195,24 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
std::function<std::shared_ptr<exec::Task>(
exec::test::AssertQueryBuilder&)> assertResults);
exec::test::AssertQueryBuilder&)> assertResults,
const std::unordered_map<std::string, std::string>& config = {});

void testReadFromFiles(
const std::vector<RowVectorPtr>& data,
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& expectedResult) {
const std::vector<RowVectorPtr>& expectedResult,
const std::unordered_map<std::string, std::string>& config = {}) {
testReadFromFiles(
[&](auto& planBuilder) { planBuilder.values(data); },
groupingKeys,
aggregates,
{},
[&](auto& assertBuilder) {
return assertBuilder.assertResults({expectedResult});
});
},
config);
}

/// Generates a variety of logically equivalent plans to compute aggregations
Expand Down Expand Up @@ -256,6 +265,15 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
vector_size_t rawInput2Size,
const std::unordered_map<std::string, std::string>& config = {});

/// Shared implementation invoked by both testAggregations() and
/// testReadFromFiles(). Takes the same source-builder, grouping keys,
/// aggregates, post-aggregation projections and result-assertion callback as
/// the public testAggregations() overload, but 'config' has no default value
/// and there is no 'testWithTableScan' flag.
void testAggregationsImpl(
std::function<void(exec::test::PlanBuilder&)> makeSource,
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<std::string>& postAggregationProjections,
std::function<std::shared_ptr<exec::Task>(
exec::test::AssertQueryBuilder&)> assertResults,
const std::unordered_map<std::string, std::string>& config);

bool allowInputShuffle_{false};
};

Expand Down
Loading

0 comments on commit 58746be

Please sign in to comment.