diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index d9085215860a..549fa4327e40 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -324,6 +324,8 @@ class ValuesNode : public PlanNode { const size_t repeatTimes_; }; +using ValuesNodePtr = std::shared_ptr; + class ArrowStreamNode : public PlanNode { public: ArrowStreamNode( diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e..30f293d53864 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,40 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::string DuckQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} + +std::unordered_map> +DuckQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::optional>> +DuckQueryRunner::execute(const core::PlanNodePtr& plan) { + if (const auto sql = toSql(plan)) { + DuckDbQueryRunner queryRunner; + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(*sql, plan->outputType()); + } + return std::nullopt; +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -341,38 +375,54 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; } - return out.str(); - }; + out << keys[i]->name(); + } + return out.str(); +} - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +} + +std::optional DuckQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } const auto& outputNames = joinNode->outputType()->names(); @@ -386,24 +436,27 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one + // Presto. A scalar subquery expression is a subquery that returns one // result row from exactly one column for every input row. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -412,29 +465,31 @@ std::optional DuckQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -448,6 +503,21 @@ std::optional DuckQueryRunner::toSql( std::optional DuckQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -458,13 +528,16 @@ std::optional DuckQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af0488..1edc65bf0ce1 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,18 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + std::optional>> execute( + const core::PlanNodePtr& plan) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 333a79c24b74..9c6f7752ae7c 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -679,9 +679,8 @@ std::optional JoinFuzzer::computeReferenceResults( VELOX_CHECK(!containsUnsupportedTypes(buildInput[0]->type())); } - if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + if (const auto result = referenceQueryRunner_->execute(plan)) { + return result; } LOG(INFO) << "Query not supported by the reference DB"; diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64d..bee32f9d93da 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -554,46 +554,64 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); } + return out.str(); +}; - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +}; - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +}; - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +std::string PrestoQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; +std::optional PrestoQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -607,13 +625,16 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in @@ -622,9 +643,9 @@ std::optional PrestoQueryRunner::toSql( if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -633,29 +654,31 @@ std::optional PrestoQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -668,6 +691,21 @@ std::optional PrestoQueryRunner::toSql( std::optional PrestoQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[0])) { + probeTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[0])) { + probeTableName = fmt::format("({})", *subQuery); + } + if (const auto valuesNode = std::dynamic_pointer_cast( + joinNode->sources()[1])) { + buildTableName = getTableName(valuesNode); + } else if (const auto subQuery = toSql(joinNode->sources()[1])) { + buildTableName = fmt::format("({})", *subQuery); + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -678,13 +716,16 @@ std::optional PrestoQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( @@ -695,11 +736,19 @@ std::optional PrestoQueryRunner::toSql( } std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& valuesNode) { + const core::ValuesNodePtr& valuesNode) { if (!isSupportedDwrfType(valuesNode->outputType())) { return std::nullopt; } - return "tmp"; + return getTableName(valuesNode); +} + +std::optional>> +PrestoQueryRunner::execute(const core::PlanNodePtr& plan) { + if (auto result = executeVector(plan)) { + return exec::test::materialize(*result); + } + return std::nullopt; } std::multiset> PrestoQueryRunner::execute( @@ -749,6 +798,59 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::unordered_map> +PrestoQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::optional> +PrestoQueryRunner::executeVector(const core::PlanNodePtr& plan) { + if (const auto sql = toSql(plan)) { + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + std::unordered_map> + inputMapWithNulls; + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMapWithNulls.insert( + {tableName, + {makeNullRows(input, fmt::format("{}x", tableName), pool())}}); + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + const std::vector& currInput = + inputMapWithNulls.contains(tableName) ? inputMapWithNulls[tableName] + : input; + auto tableDirectoryPath = createTable(tableName, currInput[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, currInput, writerPool.get()); + } + + // Run the query. + return execute(*sql); + } + return std::nullopt; +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e10..5be626181260 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -70,6 +70,9 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool isSupported(const exec::FunctionSignature& signature) override; + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + /// Creates 'tmp' table using specified data, executes SQL query generated by /// 'toSql' and returns the results. /// @@ -83,6 +86,11 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + std::optional>> execute( + const core::PlanNodePtr& plan) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +108,13 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + std::optional> executeVector( + const core::PlanNodePtr& plan) override; + std::vector executeVector( const std::string& sql, const std::vector& input, @@ -137,13 +152,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::shared_ptr& tableWriteNode); std::optional toSql( - const std::shared_ptr& joinNode); + const std::shared_ptr& joinNode); std::optional toSql( const std::shared_ptr& joinNode); - std::optional toSql( - const std::shared_ptr& valuesNode); + std::optional toSql(const core::ValuesNodePtr& valuesNode); std::string startQuery( const std::string& sql, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc24..f501d55eb97f 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -66,6 +66,11 @@ class ReferenceQueryRunner { return true; } + /// Executes the query based on the plan. Returns std::nullopt if the plan is + /// not supported. + virtual std::optional>> execute( + const core::PlanNodePtr& plan) = 0; + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( @@ -88,6 +93,12 @@ class ReferenceQueryRunner { return false; } + /// Similar to 'execute' but returns results in RowVector format. + virtual std::optional> executeVector( + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Similar to 'execute' but returns results in RowVector format. /// Caller should ensure 'supportsVeloxVectorResults' returns true. virtual std::vector executeVector( diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c..14447f5eb396 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,122 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = makeRowVector( + {"t0", "t1", "t2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto u = makeRowVector( + {"u0", "u1", "u2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto v = makeRowVector( + {"v0", "v1", "v2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto w = makeRowVector( + {"w0", "w1", "w2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (cast(v1 as BIGINT) > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).values({w}).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test