From 3befde245d09e4fb9b1eb33f955d78d983be56fe Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Sun, 15 Dec 2024 22:00:50 -0800 Subject: [PATCH] feat(fuzzer): Support multiple joins in the join node "toSql" methods for reference query runners (#11801) Summary: Currently, the hash join and nested loop join "toSql" methods for all reference query runners only support a single join. This change extends it to support multiple joins, only needing the join node of the last join in the tree. It traverses up the tree and recursively builds the sql query. Differential Revision: D66977480 --- velox/core/PlanNode.h | 2 + velox/exec/fuzzer/DuckQueryRunner.cpp | 169 +++++++++++----- velox/exec/fuzzer/DuckQueryRunner.h | 12 ++ velox/exec/fuzzer/JoinFuzzer.cpp | 5 +- velox/exec/fuzzer/PrestoQueryRunner.cpp | 212 +++++++++++++++------ velox/exec/fuzzer/PrestoQueryRunner.h | 20 +- velox/exec/fuzzer/ReferenceQueryRunner.h | 11 ++ velox/exec/tests/PrestoQueryRunnerTest.cpp | 118 ++++++++++++ 8 files changed, 440 insertions(+), 109 deletions(-) diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index d9085215860a..549fa4327e40 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -324,6 +324,8 @@ class ValuesNode : public PlanNode { const size_t repeatTimes_; }; +using ValuesNodePtr = std::shared_ptr; + class ArrowStreamNode : public PlanNode { public: ArrowStreamNode( diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e..30f293d53864 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,40 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::string DuckQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} + +std::unordered_map> +DuckQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if 
(const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::optional>> +DuckQueryRunner::execute(const core::PlanNodePtr& plan) { + if (const auto sql = toSql(plan)) { + DuckDbQueryRunner queryRunner; + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(*sql, plan->outputType()); + } + return std::nullopt; +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -341,38 +375,54 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; } - return out.str(); - }; + out << keys[i]->name(); + } + return out.str(); +} - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND 
" << filterToSql(joinNode->filter()); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +} + +std::optional DuckQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } const auto& outputNames = joinNode->outputType()->names(); @@ -386,24 +436,27 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned 
by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one + // DuckDB. A scalar subquery expression is a subquery that returns one // result row from exactly one column for every input row. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -412,29 +465,31 @@ std::optional DuckQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << 
joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -448,6 +503,21 @@ std::optional DuckQueryRunner::toSql( std::optional DuckQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -458,13 +528,16 @@ std::optional DuckQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af0488..1edc65bf0ce1 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,18 @@ class DuckQueryRunner : public 
ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Returns the name of the values node table in the form t_ID, where ID is the plan node id of the values node. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + std::optional>> execute( + const core::PlanNodePtr& plan) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 333a79c24b74..9c6f7752ae7c 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -679,9 +679,8 @@ std::optional JoinFuzzer::computeReferenceResults( VELOX_CHECK(!containsUnsupportedTypes(buildInput[0]->type())); } - if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + if (const auto result = referenceQueryRunner_->execute(plan)) { + return result; } LOG(INFO) << "Query not supported by the reference DB"; diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64d..bee32f9d93da 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -554,46 +554,64 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < 
keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); } + return out.str(); +}; - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +}; - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +}; - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +std::string PrestoQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; +std::optional PrestoQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if 
(const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -607,13 +625,16 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in @@ -622,9 +643,9 @@ std::optional PrestoQueryRunner::toSql( if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -633,29 +654,31 @@ std::optional PrestoQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if 
(joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -668,6 +691,21 @@ std::optional PrestoQueryRunner::toSql( std::optional PrestoQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = 
toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -678,13 +716,16 @@ std::optional PrestoQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( @@ -695,11 +736,19 @@ std::optional PrestoQueryRunner::toSql( } std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& valuesNode) { + const core::ValuesNodePtr& valuesNode) { if (!isSupportedDwrfType(valuesNode->outputType())) { return std::nullopt; } - return "tmp"; + return getTableName(valuesNode); +} + +std::optional>> +PrestoQueryRunner::execute(const core::PlanNodePtr& plan) { + if (auto result = executeVector(plan)) { + return exec::test::materialize(*result); + } + return std::nullopt; } std::multiset> PrestoQueryRunner::execute( @@ -749,6 +798,59 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::unordered_map> +PrestoQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), 
tablesAndNames.end()); + } + } + return result; +} + +std::optional> +PrestoQueryRunner::executeVector(const core::PlanNodePtr& plan) { + if (const auto sql = toSql(plan)) { + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + std::unordered_map> + inputMapWithNulls; + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMapWithNulls.insert( + {tableName, + {makeNullRows(input, fmt::format("{}x", tableName), pool())}}); + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + const std::vector& currInput = + inputMapWithNulls.contains(tableName) ? inputMapWithNulls[tableName] + : input; + auto tableDirectoryPath = createTable(tableName, currInput[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, currInput, writerPool.get()); + } + + // Run the query. + return execute(*sql); + } + return std::nullopt; +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e10..5be626181260 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -70,6 +70,9 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool isSupported(const exec::FunctionSignature& signature) override; + /// Returns the name of the values node table in the form t_ID, where ID is the plan node id of the values node. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + /// Creates 'tmp' table using specified data, executes SQL query generated by /// 'toSql' and returns the results. 
/// @@ -83,6 +86,11 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + std::optional>> execute( + const core::PlanNodePtr& plan) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +108,13 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + std::optional> executeVector( + const core::PlanNodePtr& plan) override; + std::vector executeVector( const std::string& sql, const std::vector& input, @@ -137,13 +152,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::shared_ptr& tableWriteNode); std::optional toSql( - const std::shared_ptr& joinNode); + const std::shared_ptr& joinNode); std::optional toSql( const std::shared_ptr& joinNode); - std::optional toSql( - const std::shared_ptr& valuesNode); + std::optional toSql(const core::ValuesNodePtr& valuesNode); std::string startQuery( const std::string& sql, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc24..f501d55eb97f 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -66,6 +66,11 @@ class ReferenceQueryRunner { return true; } + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + virtual std::optional>> execute( + const core::PlanNodePtr& plan) = 0; + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. 
virtual std::multiset> execute( @@ -88,6 +93,12 @@ class ReferenceQueryRunner { return false; } + /// Similar to 'execute' but returns results in RowVector format. + virtual std::optional> executeVector( + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Similar to 'execute' but returns results in RowVector format. /// Caller should ensure 'supportsVeloxVectorResults' returns true. virtual std::vector executeVector( diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c..14447f5eb396 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,122 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = makeRowVector( + {"t0", "t1", "t2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto u = makeRowVector( + {"u0", "u1", "u2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto v = makeRowVector( + {"v0", "v1", "v2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto w = makeRowVector( + {"w0", "w1", "w2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. 
+ { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (cast(v1 as BIGINT) > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).values({w}).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test