From 082f80d49b4c6f111ae71926923284f05e637be7 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 10 May 2022 19:03:27 +0800 Subject: [PATCH 1/4] v1.0 --- programs/server/Server.cpp | 2 + src/CMakeLists.txt | 2 + src/Common/ProfileEvents.cpp | 3 +- src/Core/Settings.h | 2 + .../ASTRewriters/ASTAnalyzeUtil.cpp | 234 ++++++ .../ASTRewriters/ASTAnalyzeUtil.h | 53 ++ .../ASTRewriters/ASTBuildUtil.cpp | 216 +++++ src/Interpreters/ASTRewriters/ASTBuildUtil.h | 82 ++ .../ASTRewriters/ASTDepthFirstVisitor.h | 104 +++ .../CollectAliasColumnElementsAction.h | 66 ++ .../ASTRewriters/CollectQueryStoragesAction.h | 74 ++ .../CollectRequiredColumnsAction.cpp | 260 ++++++ .../CollectRequiredColumnsAction.h | 59 ++ .../ASTRewriters/IdentRenameRewriteAction.cpp | 323 ++++++++ .../ASTRewriters/IdentRenameRewriteAction.h | 90 ++ .../IdentifierQualiferRemoveAction.cpp | 69 ++ .../IdentifierQualiferRemoveAction.h | 24 + .../NestedJoinQueryRewriteAction.cpp | 283 +++++++ .../NestedJoinQueryRewriteAction.h | 66 ++ ...eryDistributedAggregationRewriteAction.cpp | 369 +++++++++ ...QueryDistributedAggregationRewriteAction.h | 71 ++ ...StageQueryDistributedJoinRewriteAction.cpp | 561 +++++++++++++ .../StageQueryDistributedJoinRewriteAction.h | 107 +++ ...geQueryShuffleFinishEventRewriteAction.cpp | 105 +++ ...tageQueryShuffleFinishEventRewriteAction.h | 41 + src/Interpreters/InterpreterFactory.cpp | 6 + src/Interpreters/InterpreterStageQuery.cpp | 422 ++++++++++ src/Interpreters/InterpreterStageQuery.h | 48 ++ .../StorageDistributedTasksBuilder.cpp | 37 + .../StorageDistributedTasksBuilder.h | 38 + src/Interpreters/executeQuery.cpp | 40 + src/Parsers/ASTStageQuery.cpp | 47 ++ src/Parsers/ASTStageQuery.h | 26 + src/Processors/QueryPlan/StageQueryStep.cpp | 98 +++ src/Processors/QueryPlan/StageQueryStep.h | 39 + .../Transforms/StageQueryTransform.cpp | 324 ++++++++ .../Transforms/StageQueryTransform.h | 99 +++ .../DistributedShuffle/ShuffleBlockTable.cpp | 242 ++++++ .../DistributedShuffle/ShuffleBlockTable.h | 141 ++++ .../DistributedShuffle/StorageShuffle.cpp | 771 ++++++++++++++++++ .../DistributedShuffle/StorageShuffle.h | 179 ++++ src/TableFunctions/TableFunctionShuffle.cpp | 197 +++++ src/TableFunctions/TableFunctionShuffle.h | 111 +++ src/TableFunctions/registerTableFunctions.cpp | 2 + src/TableFunctions/registerTableFunctions.h | 2 + 45 files changed, 6134 insertions(+), 1 deletion(-) create mode 100644 src/Interpreters/ASTRewriters/ASTAnalyzeUtil.cpp create mode 100644 src/Interpreters/ASTRewriters/ASTAnalyzeUtil.h create mode 100644 src/Interpreters/ASTRewriters/ASTBuildUtil.cpp create mode 100644 src/Interpreters/ASTRewriters/ASTBuildUtil.h create mode 100644 src/Interpreters/ASTRewriters/ASTDepthFirstVisitor.h create mode 100644 src/Interpreters/ASTRewriters/CollectAliasColumnElementsAction.h create mode 100644 src/Interpreters/ASTRewriters/CollectQueryStoragesAction.h create mode 100644 src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp create mode 100644 src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.h create mode 100644 src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp create mode 100644 src/Interpreters/ASTRewriters/IdentRenameRewriteAction.h create mode 100644 src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp create mode 100644 src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.h create mode 100644 src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp create mode 100644 src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h create mode 100644 src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp create mode 100644 src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h create mode 100644 src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp create mode 100644 src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.h create mode 100644 src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.cpp create mode 100644 src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.h create mode 100644 src/Interpreters/InterpreterStageQuery.cpp create mode 100644 src/Interpreters/InterpreterStageQuery.h create mode 100644 src/Interpreters/StorageDistributedTasksBuilder.cpp create mode 100644 src/Interpreters/StorageDistributedTasksBuilder.h create mode 100644 src/Parsers/ASTStageQuery.cpp create mode 100644 src/Parsers/ASTStageQuery.h create mode 100644 src/Processors/QueryPlan/StageQueryStep.cpp create mode 100644 src/Processors/QueryPlan/StageQueryStep.h create mode 100644 src/Processors/Transforms/StageQueryTransform.cpp create mode 100644 src/Processors/Transforms/StageQueryTransform.h create mode 100644 src/Storages/DistributedShuffle/ShuffleBlockTable.cpp create mode 100644 src/Storages/DistributedShuffle/ShuffleBlockTable.h create mode 100644 src/Storages/DistributedShuffle/StorageShuffle.cpp create mode 100644 src/Storages/DistributedShuffle/StorageShuffle.h create mode 100644 src/TableFunctions/TableFunctionShuffle.cpp create mode 100644 src/TableFunctions/TableFunctionShuffle.h diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index 476725c5627c..7326e83a20e3 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -92,6 +92,7 @@ #include #include #include +#include #include "config_core.h" #include "Common/config_version.h" @@ -688,6 +689,7 @@ int Server::main(const std::vector & /*args*/) } } + registerAllStorageDistributedTaskBuilderMakers(); Poco::ThreadPool server_pool(3, config().getUInt("max_connections", 1024)); std::mutex servers_lock; std::vector servers; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 10bdc464ac67..1ecd01658c44 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -127,6 +127,7 @@ if (TARGET ch_contrib::hdfs) endif() add_headers_and_sources(dbms Storages/Cache) +add_headers_and_sources(dbms Storages/DistributedShuffle) if (TARGET ch_contrib::hivemetastore) add_headers_and_sources(dbms Storages/Hive) endif() @@ -241,6 +242,7 @@ add_object_library(clickhouse_databases_mysql Databases/MySQL) add_object_library(clickhouse_disks Disks) add_object_library(clickhouse_interpreters Interpreters) add_object_library(clickhouse_interpreters_access Interpreters/Access) +add_object_library(clickhouse_interpreters_ast_rewriters Interpreters/ASTRewriters) add_object_library(clickhouse_interpreters_mysql Interpreters/MySQL) add_object_library(clickhouse_interpreters_clusterproxy Interpreters/ClusterProxy) add_object_library(clickhouse_interpreters_jit Interpreters/JIT) diff --git a/src/Common/ProfileEvents.cpp b/src/Common/ProfileEvents.cpp index b8e552f60234..19c5a41f97e0 100644 --- a/src/Common/ProfileEvents.cpp +++ b/src/Common/ProfileEvents.cpp @@ -343,7 +343,8 @@ \ M(ScalarSubqueriesGlobalCacheHit, "Number of times a read from a scalar subquery was done using the global cache") \ M(ScalarSubqueriesLocalCacheHit, "Number of times a read from a scalar subquery was done using the local cache") \ - M(ScalarSubqueriesCacheMiss, "Number of times a read from a scalar subquery was not cached and had to be calculated completely") + M(ScalarSubqueriesCacheMiss, "Number of times a read from a scalar subquery was not cached and had to be calculated completely") \ + M(ClearTimeoutShuffleStorageSession, "Number of sessions cleared by timeout") \ namespace ProfileEvents { diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 54d1f1a6d88f..8d441e8d8b9f 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -599,6 +599,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value) M(Bool, allow_experimental_object_type, false, "Allow Object and JSON data types", 0) \ M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \ M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \ + M(String, use_cluster_for_distributed_shuffle, "", "If you want to run the join and group by in distributed shuffle mode, set it as one of the available cluster.", 0) \ + M(UInt64, shuffle_storage_session_timeout, 1800, "How long a session can be alive before expired by timeout", 0) \ M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \ M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \ M(Bool, throw_if_no_data_to_insert, true, "Enables or disables empty INSERTs, enabled by default", 0) \ diff --git a/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.cpp b/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.cpp new file mode 100644 index 000000000000..4913d0da6419 --- /dev/null +++ b/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ + +String ColumnWithDetailNameAndType::toString() const +{ + WriteBufferFromOwnString buf; + buf << "full_name: " << full_name << ", short_name: " << short_name + << ", alias_name: " << alias_name; + buf << ", data_type: " << type->getName(); + return buf.str(); +} + +NamesAndTypesList ColumnWithDetailNameAndType::toNamesAndTypesList(const std::vector & columns) +{ + std::list names_and_types; + for (const auto & col : columns) + { + names_and_types.emplace_back(NameAndTypePair(col.short_name, col.type)); + } + NamesAndTypesList res(names_and_types.begin(), names_and_types.end()); + return res; +} + +void ColumnWithDetailNameAndType::makeAliasByFullName(std::vector & columns) +{ + for (auto & column : columns) + { + if (column.full_name != column.short_name && column.alias_name.empty()) + { + column.alias_name = column.full_name; + std::replace(column.alias_name.begin(), column.alias_name.end(), '.', '_'); + } + } +} + +std::vector ColumnWithDetailNameAndType::splitedFullName() const +{ + Poco::StringTokenizer splitter(full_name, "."); + std::vector res; + for (const auto & token : splitter) + { + res.push_back(token); + } + return res; +} + +bool ASTAnalyzeUtil::hasGroupByRecursively(const ASTPtr & ast) +{ + return hasGroupByRecursively(ast.get()); +} +bool ASTAnalyzeUtil::hasGroupByRecursively(const IAST * ast) +{ + if (!ast) + return false; + if (const auto * insert_ast = ast->as()) + { + return hasGroupByRecursively(insert_ast->select); + } + else if (const auto * select_with_union = ast->as()) + { + for (auto & child : select_with_union->list_of_selects->children) + { + if (hasGroupByRecursively(child)) + return true; + } + } + else if (const auto * select_ast = ast->as()) + { + if (select_ast->groupBy() != nullptr) + return true; + return hasGroupByRecursively(select_ast->groupBy().get()); + } + else if (const auto * tables_ast = ast->as()) + { + for (const auto & child : tables_ast->children) + { + if (hasGroupByRecursively(child.get())) + return true; + } + } + else if (const auto * table_element = ast->as()) + { + const auto * table_expr = table_element->table_expression->as(); + return hasGroupByRecursively(table_expr->subquery.get()); + } + else if (const auto * subquery = ast->as()) + { + for (const auto & child : subquery->children) + { + if (hasGroupByRecursively(child.get())) + return true; + } + } + return false; +} + + +bool ASTAnalyzeUtil::hasGroupBy(const ASTPtr & ast) +{ + return hasGroupBy(ast.get()); +} + +bool ASTAnalyzeUtil::hasGroupBy(const IAST * ast) +{ + if (const auto * select_with_union_ast = ast->as()) + { + if (select_with_union_ast->list_of_selects->children.size() > 1) + return false; + return hasGroupBy(select_with_union_ast->list_of_selects->children[0]); + } + else if (const auto * select_ast = ast->as()) + { + return select_ast->groupBy() != nullptr; + } + return false; +} + +bool ASTAnalyzeUtil::hasAggregationColumn(const ASTPtr & ast) +{ + return hasAggregationColumn(ast.get()); +} +bool ASTAnalyzeUtil::hasAggregationColumn(const IAST * ast) +{ + if (const auto * select_ast = ast->as()) + { + const auto * select_list = select_ast->select()->as(); + for (const auto & child : select_list->children) + { + if (const auto * function = child->as()) + { + if (function->name == "count" || function->name == "avg" || function->name == "sum") + { + return true; + } + } + } + } + return false; +} + +bool ASTAnalyzeUtil::hasAggregationColumnRecursively(const ASTPtr & ast) +{ + return hasAggregationColumnRecursively(ast.get()); +} + +bool ASTAnalyzeUtil::hasAggregationColumnRecursively(const IAST * ast) +{ + if (!ast) + return false; + if (const auto * insert_ast = ast->as()) + { + return hasAggregationColumnRecursively(insert_ast->select.get()); + } + else if (const auto * select_with_union_ast = ast->as()) + { + for (const auto & child : select_with_union_ast->list_of_selects->children) + { + if (hasAggregationColumnRecursively(child.get())) + return true; + } + } + else if (const auto * select_ast = ast->as()) + { + if (hasAggregationColumn(select_ast)) + return true; + return hasAggregationColumnRecursively(select_ast->tables().get()); + } + else if (const auto * tables_ast = ast->as()) + { + for (const auto & child : tables_ast->children) + { + if (hasAggregationColumnRecursively(child.get())) + return true; + } + } + else if (const auto * table_element = ast->as()) + { + const auto * table_expr = table_element->table_expression->as(); + return hasAggregationColumnRecursively(table_expr->subquery.get()); + } + else if (const auto * subquery = ast->as()) + { + for (const auto & child : subquery->children) + { + if (hasAggregationColumnRecursively(child.get())) + return true; + } + } + return false; +} + +String ASTAnalyzeUtil::tryGetTableExpressionAlias(const ASTTableExpression * table_expr) +{ + String res; + if (table_expr->table_function) + { + res = table_expr->table_function->as()->tryGetAlias(); + } + else if (table_expr->subquery) + { + res = table_expr->subquery->as()->tryGetAlias(); + } + else if (table_expr->database_and_table_name) + { + if (const auto * with_alias_ast = table_expr->database_and_table_name->as()) + { + res = with_alias_ast->tryGetAlias(); + if (res.empty()) + { + res = with_alias_ast->shortName(); + } + } + } + return res; +} + +} diff --git a/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.h b/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.h new file mode 100644 index 000000000000..19928dd138a8 --- /dev/null +++ b/src/Interpreters/ASTRewriters/ASTAnalyzeUtil.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +struct ColumnWithDetailNameAndType +{ + String full_name; + String short_name; + String alias_name; + DataTypePtr type; + String toString() const; + + static void makeAliasByFullName(std::vector & columns); + static NamesAndTypesList toNamesAndTypesList(const std::vector & columns); + std::vector splitedFullName() const; +}; +using ColumnWithDetailNameAndTypes = std::vector; +class ASTAnalyzeUtil +{ +public: + static bool hasGroupByRecursively(const ASTPtr & ast); + static bool hasGroupBy(const ASTPtr & ast); + static bool hasGroupByRecursively(const IAST * ast); + static bool hasGroupBy(const IAST * ast); + + //static bool hasAggregationColumnRecursively(ASTPtr ast); + static bool hasAggregationColumn(const ASTPtr & ast); + static bool hasAggregationColumn(const IAST * ast); + static bool hasAggregationColumnRecursively(const ASTPtr & ast); + static bool hasAggregationColumnRecursively(const IAST * ast); + static String tryGetTableExpressionAlias(const ASTTableExpression * table_expr); + +}; + + +class ShuffleTableIdGenerator +{ +public: + ShuffleTableIdGenerator():id(0){} + inline UInt32 nextId() { return id++; } +private: + UInt32 id; +}; +using ShuffleTableIdGeneratorPtr = std::shared_ptr; +} diff --git a/src/Interpreters/ASTRewriters/ASTBuildUtil.cpp b/src/Interpreters/ASTRewriters/ASTBuildUtil.cpp new file mode 100644 index 000000000000..6457e85a4a54 --- /dev/null +++ b/src/Interpreters/ASTRewriters/ASTBuildUtil.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +String ASTBuildUtil::getTableExpressionAlias(const ASTTableExpression * ast) +{ + String res; + if (ast->table_function) + { + res = ast->table_function->as()->tryGetAlias(); + } + else if (ast->subquery) + { + res = ast->subquery->as()->tryGetAlias(); + } + else if (ast->database_and_table_name) + { + if (const auto * with_alias_ast = ast->database_and_table_name->as()) + res = with_alias_ast->tryGetAlias(); + } + return res; +} + +std::shared_ptr ASTBuildUtil::toShortNameExpressionList(const ColumnWithDetailNameAndTypes & columns) +{ + auto expression_list = std::make_shared(); + for (const auto & col : columns) + { + auto ident = std::make_shared(col.short_name); + expression_list->children.push_back(ident); + } + return expression_list; +} +String ASTBuildUtil::toTableStructureDescription(const ColumnWithDetailNameAndTypes & columns) +{ + WriteBufferFromOwnString buf; + int i = 0; + for (const auto & col : columns) + { + if (i) + { + buf << ","; + } + buf << col.short_name << " " << col.type->getName(); + i++; + } + return buf.str(); +} + + +ASTPtr ASTBuildUtil::createShuffleTableFunction( + const String & function_name, + const String & cluster_name, + const String & session_id, + const String & table_id, + const NamesAndTypesList & columns, + const ASTPtr & hash_expression_list, + const String & alias) +{ + auto table_func = std::make_shared(); + table_func->name = function_name; + table_func->arguments = std::make_shared(); + table_func->children.push_back(table_func->arguments); + + Field cluster_name_field(cluster_name); + table_func->arguments->children.push_back(std::make_shared(cluster_name_field)); + + table_func->arguments->children.push_back(std::make_shared(Field(session_id))); + table_func->arguments->children.push_back(std::make_shared(Field(table_id))); + + WriteBufferFromOwnString struct_buf; + int i = 0; + for (const auto & name_and_type : columns) + { + if (i) + struct_buf << ","; + struct_buf << name_and_type.name << " " << name_and_type.type->getName(); + i++; + } + auto hash_table_structure = std::make_shared(struct_buf.str()); + table_func->arguments->children.push_back(hash_table_structure); + + if (hash_expression_list) + { + auto hash_table_key = std::make_shared(queryToString(hash_expression_list)); + table_func->arguments->children.push_back(hash_table_key); + } + if (!alias.empty()) + table_func->setAlias(alias); + return table_func; +} + +ASTPtr ASTBuildUtil::createTableFunctionInsertSelectQuery(ASTPtr table_function, ASTPtr select_query) +{ + if (!table_function->as() || !select_query->as()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalide ast type."); + } + + auto insert_query = std::make_shared(); + insert_query->table_function = table_function; + insert_query->select = select_query; + return insert_query; +} + +ASTPtr ASTBuildUtil::wrapSelectQuery(const ASTSelectQuery * select_query) +{ + auto list_of_selects = std::make_shared(); + list_of_selects->children.push_back(select_query->clone()); + auto select_with_union_query = std::make_shared(); + select_with_union_query->children.push_back(list_of_selects); + select_with_union_query->list_of_selects = list_of_selects; + return select_with_union_query; +} + +ASTPtr ASTBuildUtil::createSelectExpression(const NamesAndTypesList & names_and_types) +{ + auto select_expression = std::make_shared(); + for (const auto & name_and_type : names_and_types) + { + auto ident = std::make_shared(name_and_type.name); + select_expression->children.push_back(ident); + } + return select_expression; +} + +void ASTBuildUtil::updateSelectQueryTables(ASTSelectQuery * select_query, const ASTTableExpression * table_expr_) +{ + auto table_expr = table_expr_->clone(); + select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::make_shared()); + auto tables_in_select = select_query->tables(); + auto table_element = std::make_shared(); + table_element->children.push_back(table_expr); + table_element->table_expression = table_expr; + tables_in_select->children.push_back(table_element); +} + +void ASTBuildUtil::updateSelectQueryTables(ASTSelectQuery * select_query, const ASTTablesInSelectQueryElement * table_element_) +{ + auto table_element = table_element_->clone(); + select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::make_shared()); + auto tables = select_query->tables(); + tables->children.push_back(table_element); +} + + +void ASTBuildUtil::updateSelectQueryTables( + ASTSelectQuery * select_query, + const ASTTablesInSelectQueryElement * left_table_element_, + const ASTTablesInSelectQueryElement * right_table_element_) +{ + auto left_table_element = left_table_element_->clone(); + auto right_table_element = right_table_element_->clone(); + select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::make_shared()); + auto tables = select_query->tables(); + tables->children.push_back(left_table_element); + tables->children.push_back(right_table_element); + +} + +ASTPtr ASTBuildUtil::createTablesInSelectQueryElement(const ASTTableExpression * table_expr_, ASTPtr table_join_) +{ + auto table_expr = table_expr_->clone(); + auto table_element = std::make_shared(); + table_element->children.push_back(table_expr); + table_element->table_expression = table_expr; + if (table_join_) + table_element->table_join = table_join_->clone(); + return table_element; +} +ASTPtr ASTBuildUtil::createTablesInSelectQueryElement(const ASTFunction * func_, ASTPtr table_join_) +{ + auto func = func_->clone(); + auto table_expr = std::make_shared(); + table_expr->table_function = func; + + auto table_element = std::make_shared(); + table_element->children.push_back(table_expr); + table_element->table_expression = table_expr; + if (table_join_) + table_element->table_join = table_join_->clone(); + return table_element; +} + +ASTPtr ASTBuildUtil::createTablesInSelectQueryElement(const ASTSelectWithUnionQuery * select_query, const String & alias) +{ + auto table_expr = std::make_shared(); + table_expr->subquery = std::make_shared(); + table_expr->subquery->children.push_back(select_query->clone()); + if (!alias.empty()) + { + table_expr->subquery->as()->alias = alias; + } + + auto table_element = std::make_shared(); + table_element->children.push_back(table_expr); + table_element->table_expression = table_expr; + + return table_element; +} + +} diff --git a/src/Interpreters/ASTRewriters/ASTBuildUtil.h b/src/Interpreters/ASTRewriters/ASTBuildUtil.h new file mode 100644 index 000000000000..cfc042960229 --- /dev/null +++ b/src/Interpreters/ASTRewriters/ASTBuildUtil.h @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +class ASTBuildUtil +{ +public: + static String getTableExpressionAlias(const ASTTableExpression * ast); + static std::shared_ptr toShortNameExpressionList(const ColumnWithDetailNameAndTypes & columns); + static String toTableStructureDescription(const ColumnWithDetailNameAndTypes & columns); + + /** + * @brief Create a Shuffle Table Function object + * + * @param function_name which shuffle function to use. see TableFunctionShuffle.h + * @param session_id session_od + * @param cluster_name cluster name + * @param table_id table_id + * @param columns describe the table structure. etc. 'x int, y string' + * @param hash_expression_list hash expression list for shuffle hashing. etc. 'x, y' + * @param alias table alias + * @return ASTPtr ASTFunction + */ + static ASTPtr createShuffleTableFunction( + const String & function_name, + const String & cluster_name, + const String & session_id, + const String & table_id, + const NamesAndTypesList & columns, + const ASTPtr & hash_expression_list, + const String & alias = ""); + + /** + * @brief Create a Table Function Insert Select Query object + * + * @param table_function must be a ASTFunction + * @param select_query must be a ASTSelectWithUnionQuery + * @return ASTPtr it's a ASTInsertQuery + */ + static ASTPtr createTableFunctionInsertSelectQuery(ASTPtr table_function, ASTPtr select_query); + + /** + * @brief Create a ASTSelectWithUnionQuery with a ASTSelectQuery + * + * @param select_query must be a ASTSelectQuery + * @return ASTPtr it's a ASTSelectWithUnionQuery + */ + static ASTPtr wrapSelectQuery(const ASTSelectQuery * select_query); + + /** + * @brief Create a Select Expression object + * + * @param names_and_types Use the names to build the select expression + * @return ASTPtr + */ + static ASTPtr createSelectExpression(const NamesAndTypesList & names_and_types); + + /** + * @brief Update ASTSelectQuery::TABLES ASTTableExpressions + */ + static void updateSelectQueryTables(ASTSelectQuery * select_query, const ASTTableExpression * table_expr_); + + static void updateSelectQueryTables(ASTSelectQuery * select_query, const ASTTablesInSelectQueryElement * table_element_); + + static void updateSelectQueryTables( + ASTSelectQuery * select_query, + const ASTTablesInSelectQueryElement * left_table_element_, + const ASTTablesInSelectQueryElement * right_table_element_); + + static ASTPtr createTablesInSelectQueryElement(const ASTTableExpression * table_expr_, ASTPtr table_join_ = nullptr); + static ASTPtr createTablesInSelectQueryElement(const ASTFunction * func_, ASTPtr table_join_ = nullptr); + static ASTPtr createTablesInSelectQueryElement(const ASTSelectWithUnionQuery * select_query, const String & alias = ""); + +}; + +} diff --git a/src/Interpreters/ASTRewriters/ASTDepthFirstVisitor.h b/src/Interpreters/ASTRewriters/ASTDepthFirstVisitor.h new file mode 100644 index 000000000000..f61bc20cb356 --- /dev/null +++ b/src/Interpreters/ASTRewriters/ASTDepthFirstVisitor.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ +template +class ASTDepthFirstVisitor +{ +public: + using Result = typename Action::Result; + ASTDepthFirstVisitor(Action & action_, const ASTPtr & ast_) + : action(action_) + , original_ast(ast_) + {} + + Result visit() + { + visit(original_ast); + return action.getResult(); + } + +private: + Action & action; + ASTPtr original_ast; + void visit(const ASTPtr & ast) + { + ASTs children = action.collectChildren(ast); + action.beforeVisitChildren(ast); + for (const auto & child : children) + { + action.beforeVisitChild(child); + visit(child); + action.afterVisitChild(child); + } + action.afterVisitChildren(ast); + + return action.visit(ast); + } +}; +struct SimpleVisitFrame +{ + explicit SimpleVisitFrame(const ASTPtr & ast) : original_ast(ast) { } + ASTPtr original_ast; + ASTPtr result_ast; + ASTs children_results; +}; +using SimpleVisitFramePtr = std::shared_ptr; + +template +struct SimpleVisitFrameStack +{ + std::list> stack; + + inline void pushFrame(const ASTPtr ast) + { + stack.emplace_back(std::make_shared(ast)); + } + + inline void pushFrame(std::shared_ptr frame) + { + stack.emplace_back(frame); + } + + inline std::shared_ptr getTopFrame() + { + if (stack.empty()) + return nullptr; + return stack.back(); + } + inline std::shared_ptr getPrevFrame() + { + if (stack.size() < 2) + return nullptr; + auto iter = stack.rbegin(); + iter++; + return *iter; + } + inline void popFrame() + { + stack.pop_back(); + } + + inline size_t size() const + { + return stack.size(); + } +}; + +class EmptyASTDepthFirstVisitAction +{ +public: + + virtual ~EmptyASTDepthFirstVisitAction() = default; + virtual ASTs collectChildren(const ASTPtr & /*current_ast*/) { return {}; } + virtual void beforeVisitChildren(const ASTPtr & /*current_ast*/) {} + virtual void beforeVisitChild(const ASTPtr & /*child*/) {} + virtual void afterVisitChild(const ASTPtr & /*child*/) {} + virtual void afterVisitChildren(const ASTPtr & /*current_ast*/) {} +}; +} diff --git a/src/Interpreters/ASTRewriters/CollectAliasColumnElementsAction.h b/src/Interpreters/ASTRewriters/CollectAliasColumnElementsAction.h new file mode 100644 index 000000000000..9e10458dd8ab --- /dev/null +++ b/src/Interpreters/ASTRewriters/CollectAliasColumnElementsAction.h @@ -0,0 +1,66 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +class CollectAliasColumnElementAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = std::map; + ~CollectAliasColumnElementAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override + { + if (!ast) + return {}; + if (const auto * expr_list_ast = ast->as()) + { + return expr_list_ast->children; + } + else if (const auto * select_ast = ast->as()) + { + return ASTs{select_ast->select()}; + } + else if (const auto * union_select_ast = ast->as()) + { + return union_select_ast->list_of_selects->children; + } + return {}; + } + + void visit(const ASTPtr & ast) + { + if (const auto * ident_ast = ast->as()) + { + auto alias = ident_ast->tryGetAlias(); + if (!alias.empty()) + { + result[alias] = ast->clone(); + } + + } + else if (const auto * function_ast = ast->as()) + { + auto alias = function_ast->tryGetAlias(); + if (!alias.empty()) + { + result[alias] = ast->clone(); + } + } + } + + Result getResult() { return result; } + +private: + Result result; + +}; +} diff --git a/src/Interpreters/ASTRewriters/CollectQueryStoragesAction.h b/src/Interpreters/ASTRewriters/CollectQueryStoragesAction.h new file mode 100644 index 000000000000..0a41db75c75e --- /dev/null +++ b/src/Interpreters/ASTRewriters/CollectQueryStoragesAction.h @@ -0,0 +1,74 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +class CollectQueryStoragesAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = std::vector; + explicit CollectQueryStoragesAction(ContextPtr context_) : context(context_) {} + ~CollectQueryStoragesAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override + { + if (!ast) + return {}; + ASTs children; + if (const auto * union_select_ast = ast->as()) + { + children = union_select_ast->list_of_selects->children; + } + else if (const auto * select_ast = ast->as()) + { + if (const auto * left_table_expr = getTableExpression(*select_ast, 0)) + children.emplace_back(left_table_expr->clone()); + if (const auto * right_table_expr = getTableExpression(*select_ast, 1)) + children.emplace_back(right_table_expr->clone()); + } + else if (const auto * table_expr_ast = ast->as()) + { + if (table_expr_ast->subquery) + children.emplace_back(table_expr_ast->subquery); + } + else if (const auto * subquery_ast = ast->as()) + { + children = subquery_ast->children; + } + return children; + } + + void visit(const ASTPtr & ast) + { + if (const auto * table_expr_ast = ast->as()) + { + if (table_expr_ast->database_and_table_name) + { + auto * table_ident = table_expr_ast->database_and_table_name->as(); + auto table_id = context->resolveStorageID(table_ident->getTableId()); + auto storage = DatabaseCatalog::instance().getTable(table_id, context); + if (storage) + storages.emplace_back(storage); + } + else if (table_expr_ast->table_function) + { + storages.push_back(context->getQueryContext()->executeTableFunction(table_expr_ast->table_function)); + } + } + } + Result getResult() { return storages; } +private: + ContextPtr context; + std::vector storages; +}; +} diff --git a/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp new file mode 100644 index 000000000000..87844036d8b0 --- /dev/null +++ b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp @@ -0,0 +1,260 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int AMBIGUOUS_COLUMN_NAME; + extern const int LOGICAL_ERROR; +} + +void CollectRequiredColumnsAction::beforeVisitChild(const ASTPtr & /*ast*/) {} + +ASTs CollectRequiredColumnsAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + if (const auto * function_ast = ast->as()) + { + if (Poco::toLower(function_ast->name) == "count") + return {}; + return function_ast->arguments->children; + } + else if (const auto * select_ast = ast->as()) + { + LOG_TRACE(&Poco::Logger::get("CollectRequiredColumnsAction"), "{} match select {}", __LINE__, queryToString(ast)); + ASTs children; + children.emplace_back(select_ast->select()); + children.emplace_back(select_ast->where()); + children.emplace_back(select_ast->groupBy()); + children.emplace_back(select_ast->orderBy()); + return children; + } + else if (const auto * expr_list_ast = ast->as()) + { + return expr_list_ast->children; + } + else if (const auto * join_ast = ast->as()) + { + ASTs children; + if (join_ast->using_expression_list) + children.emplace_back(join_ast->using_expression_list); + + if (join_ast->on_expression) + children.emplace_back(join_ast->on_expression); + return children; + } + else if (const auto * orderby_ast = ast->as()) + { + return orderby_ast->children; + } + else if (const auto * ident_ast = ast->as()) + { + ASTs children; + if (alias_asts.count(ident_ast->name())) + children.emplace_back(alias_asts[ident_ast->name()]); + return children; + } + + LOG_TRACE(&Poco::Logger::get("CollectRequiredColumnsAction"), "{} unknow ast({}) : {}", __LINE__, ast->getID(), queryToString(ast)); + return {}; +} + +void CollectRequiredColumnsAction::afterVisitChild(const ASTPtr & /*ast*/) {} + +void CollectRequiredColumnsAction::visit(const ASTPtr & ast) +{ + if (!ast) + return; + + if (const auto * function_ast = ast->as()) + { + visit(function_ast); + } + else if (const auto * ident_ast = ast->as()) + { + visit(ident_ast); + } + else if (const auto * asterisk_ast = ast->as()) + { + visit(asterisk_ast); + } + else if (const auto * qualified_asterisk_ast = ast->as()) + { + visit(qualified_asterisk_ast); + } +} + +void CollectRequiredColumnsAction::visit(const ASTFunction * function_ast) +{ + if (Poco::toLower(function_ast->name) != "count") + return; + + if (!final_result.required_columns[0].empty()) + return; + + const auto & table = tables[0]; + for (const auto & col : table.columns) + { + if (added_names.count(col.name)) + continue; + ColumnWithDetailNameAndType column_metadta + = {.full_name = table.table.alias.empty() ? col.name : table.table.alias + "." + col.name, + .short_name = col.name, + .alias_name = "", + .type = col.type}; + final_result.required_columns[0].push_back(column_metadta); + added_names.insert(col.name); + } +} + +void CollectRequiredColumnsAction::visit(const ASTIdentifier * ident_ast) +{ + if (alias_asts.count(ident_ast->name())) + return; + if (auto best_pos = IdentifierSemantic::chooseTableColumnMatch(*ident_ast, tables, false)) + { + bool found = false; + if (*best_pos < tables.size()) + { + for (const auto & col : tables[*best_pos].columns) + { + if (col.name == ident_ast->shortName()) + { + if (added_names.count(ident_ast->name())) + { + continue; + } + ColumnWithDetailNameAndType column_metadta = { + .full_name = ident_ast->name(), + .short_name = ident_ast->shortName(), + .alias_name = ident_ast->tryGetAlias(), + .type = col.type + }; + final_result.required_columns[*best_pos].push_back(column_metadta); + found = true; + added_names.insert(ident_ast->name()); + if (!ident_ast->alias.empty()) + alias_names.insert(ident_ast->alias); + break; + } + } + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid table pos: {} for ident: {}", *best_pos, queryToString(*ident_ast)); + } + if (!found) + { + LOG_TRACE(&Poco::Logger::get("CollectRequiredColumnsAction"), "Not found match column for {} - {} ", queryToString(*ident_ast), ident_ast->name()); + } + } + else + { + if (!alias_names.count(ident_ast->name())) + { + throw Exception( + ErrorCodes::AMBIGUOUS_COLUMN_NAME, + "Position of identifier {} can't be deteminated. full_name={}, short_name={}, alias={}", + queryToString(*ident_ast), + ident_ast->name(), + ident_ast->shortName(), + ident_ast->alias); + } + } + +} + +void CollectRequiredColumnsAction::visit(const ASTAsterisk * /*asterisk_ast*/) +{ + for (size_t i = 0; i < tables.size(); ++i) + { + const auto & table_with_cols = tables[i]; + const auto & table = table_with_cols.table; + const auto & cols = table_with_cols.columns; + auto & required_cols = final_result.required_columns[i]; + String qualifier; + if (!table.alias.empty()) + qualifier = table.alias; + else if (table.table.empty()) + qualifier = table.table; + + for (const auto & col : cols) + { + bool has_exists = false; + for (const auto & added_col : required_cols) + { + if (added_col.short_name == col.name) + { + has_exists = true; + break; + } + } + if (has_exists) + continue; + ColumnWithDetailNameAndType to_add_col = { + .full_name = qualifier.empty() ? col.name : qualifier + "." + col.name, + .short_name = col.name, + .type = col.type + }; + required_cols.emplace_back(to_add_col); + } + } +} + +void CollectRequiredColumnsAction::visit(const ASTQualifiedAsterisk * qualified_asterisk) +{ + const auto & ident = qualified_asterisk->children[0]; + DatabaseAndTableWithAlias db_and_table(ident); + for (size_t i = 0; i < tables.size(); ++i) + { + const auto & table_with_cols = tables[i]; + const auto & table = table_with_cols.table; + const auto & cols = table_with_cols.columns; + auto & required_cols = final_result.required_columns[i]; + if (!db_and_table.satisfies(table, true)) + continue; + + String qualifier = queryToString(ident); + + for (const auto & col : cols) + { + bool has_exists = false; + for (const auto & added_col : required_cols) + { + if (added_col.short_name == col.name) + { + has_exists = true; + break; + } + } + if (has_exists) + continue; + ColumnWithDetailNameAndType to_add_col = { + .full_name = qualifier.empty() ? col.name : qualifier + "." + col.name, + .short_name = col.name, + .type = col.type + }; + required_cols.emplace_back(to_add_col); + } + } + +} + +} diff --git a/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.h b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.h new file mode 100644 index 000000000000..6353e46502eb --- /dev/null +++ b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.h @@ -0,0 +1,59 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +class CollectRequiredColumnsAction : public EmptyASTDepthFirstVisitAction +{ +public: + struct Result + { + std::vector required_columns; + }; + + explicit CollectRequiredColumnsAction(const TablesWithColumns & tables_, const std::map alias_asts_ = {}) + : tables(tables_), alias_asts(alias_asts_) + { + for (size_t i = 0; i < tables.size(); ++i) + { + final_result.required_columns.emplace_back(ColumnWithDetailNameAndTypes{}); + } + } + ~CollectRequiredColumnsAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChild(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + Result getResult() const { return final_result; } + +private: + const TablesWithColumns & tables; + Result final_result; + std::set added_names; // not need to filled by outside; + std::set alias_names; + std::map alias_asts; // asts with alias in the select expression list + + void visit(const ASTFunction * function_ast); + void visit(const ASTIdentifier * ident_ast); + void visit(const ASTAsterisk * asterisk_ast); + void visit(const ASTQualifiedAsterisk * qualified_asterisk); +}; +} diff --git a/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp new file mode 100644 index 000000000000..2294d200b254 --- /dev/null +++ b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp @@ -0,0 +1,323 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +ASTs IdentifierRenameAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + if (const auto * function_ast = ast->as()) + { + return function_ast->arguments->children; + } + else if (const auto * expr_list_ast = ast->as()) + { + return expr_list_ast->children; + } + else if (const auto * orderby_ast = ast->as()) + { + ASTs children; + children.emplace_back(orderby_ast->collation); + children.emplace_back(orderby_ast->fill_from); + children.emplace_back(orderby_ast->fill_to); + children.emplace_back(orderby_ast->fill_step); + return children; + } + else if (const auto * ident_ast = ast->as()) + { + return {}; + } + else if (const auto * asterisk_ast = ast->as()) + { + return {}; + } + else if (const auto * qualified_asterisk_ast = ast->as()) + { + return {}; + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknow ast type {}. {}", ast->getID(), queryToString(ast)); + } +} +void IdentifierRenameAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void IdentifierRenameAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto child_result = frames.getTopFrame()->result_ast; + frames.popFrame(); + assert(frames.size() != 0); + frames.getTopFrame()->children_results.emplace_back(child_result); +} + +void IdentifierRenameAction::visit(const ASTPtr & ast) +{ + assert(!frames.empty()); + if (!ast) + return; + if (const auto * ident_ast = ast->as()) + { + visit(ident_ast); + } + else if (ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + auto * result_function_ast = frame->result_ast->as(); + result_function_ast->arguments->children = frame->children_results; + } + else if (ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + auto * result_expr_list_ast = frame->result_ast->as(); + result_expr_list_ast->children = frame->children_results; + } + else if (ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + auto * result_orderby_ast = frame->result_ast->as(); + result_orderby_ast->collation = frame->children_results[0]; + result_orderby_ast->fill_from = frame->children_results[1]; + result_orderby_ast->fill_to = frame->children_results[2]; + result_orderby_ast->fill_step = frame->children_results[3]; + } + else + { + auto frame = frames.getTopFrame(); + if (frame->original_ast) + frame->result_ast = frame->original_ast->clone(); + } + +} + +void IdentifierRenameAction::visit(const ASTIdentifier * ident_ast) +{ + auto frame = frames.getTopFrame(); + const auto & name = ident_ast->name(); + auto iter = renamed_idents.find(name); + if (iter == renamed_idents.end()) + frame->result_ast = ident_ast->clone(); + else + { + auto result_ast = std::make_shared(iter->second); + result_ast->alias = ident_ast->tryGetAlias(); + frame->result_ast = result_ast; + } +} + +ASTs MakeFunctionColumnAliasAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + if (const auto * expr_list_ast = ast->as()) + { + auto prev_frame = frames.getTopFrame(); + if (!prev_frame || prev_frame->original_ast->as()) + { + LOG_TRACE(&Poco::Logger::get("MakeFunctionColumnAliasAction"), "is expr list from select. {}", queryToString(ast)); + return expr_list_ast->children; + } + return {}; + } + else if (const auto * table_expr = ast->as()) + { + ASTs children; + if (table_expr->subquery) + { + children.emplace_back(table_expr->subquery); + } + return children; + } + else if (const auto * subquery = ast->as()) + { + ASTs children; + children.emplace_back(subquery->children[0]); + return children; + } + else if (const auto * select_ast = ast->as()) + { + ASTs children; + children.emplace_back(select_ast->select()); + + if (const auto * left_table_expr = getTableExpression(*select_ast, 0)) + { + children.emplace_back(left_table_expr->clone()); + } + + if (const auto * right_table_expr = getTableExpression(*select_ast, 1)) + { + children.emplace_back(right_table_expr->clone()); + } + return children; + } + else if (const auto * select_with_union_ast = ast->as()) + { + return select_with_union_ast->list_of_selects->children; + } + return {}; +} + +void MakeFunctionColumnAliasAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void MakeFunctionColumnAliasAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto child_result = frames.getTopFrame()->result_ast; + frames.popFrame(); + assert(frames.size() != 0); + frames.getTopFrame()->children_results.emplace_back(child_result); +} + +void MakeFunctionColumnAliasAction::visit(const ASTPtr & ast) +{ + if (const auto * function_ast = ast->as()) + { + visit(function_ast); + } + else if (const auto * expr_list_ast = ast->as()) + { + visit(expr_list_ast); + } + else if (const auto * table_expr = ast->as()) + { + visit(table_expr); + } + else if (const auto * subquery = ast->as()) + { + visit(subquery); + } + else if (const auto * select_ast = ast->as()) + { + visit(select_ast); + } + else if (const auto * select_with_union_ast = ast->as()) + { + visit(select_with_union_ast); + } + else + { + frames.getTopFrame()->result_ast = frames.getTopFrame()->original_ast; + } +} + +void MakeFunctionColumnAliasAction::visit(const ASTFunction * function_ast) +{ + auto frame = frames.getTopFrame(); + if (!function_ast->tryGetAlias().empty()) + { + frame->result_ast = function_ast->clone(); + return; + } + auto iter = functions_alias_id->find(function_ast->name); + size_t id = 0; + if (iter == functions_alias_id->end()) + { + (*functions_alias_id)[function_ast->name] = 0; + } + else + { + id = iter->second++; + } + frame->result_ast = function_ast->clone(); + frame->result_ast->setAlias(function_ast->name + "_" + std::to_string(id)); + LOG_TRACE(&Poco::Logger::get("MakeFunctionColumnAliasAction"), "make alias for function: {}", queryToString(frame->result_ast)); +} + +void MakeFunctionColumnAliasAction::visit(const ASTExpressionList * expr_list_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = expr_list_ast->clone(); + auto * result_expr_ast = frame->result_ast->as(); + auto prev_frame = frames.getPrevFrame(); + if (!prev_frame || prev_frame->original_ast->as()) + { + LOG_TRACE(&Poco::Logger::get("MakeFunctionColumnAliasAction"), "ASTExpressionList used rewrite children. size: {}. {}", frame->children_results.size(), queryToString(*expr_list_ast)); + result_expr_ast->children = frame->children_results; + } +} + + +void MakeFunctionColumnAliasAction::visit(const ASTTableExpression * table_expr_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = table_expr_ast->clone(); + auto * result_table_expr_ast = frame->result_ast->as(); + if (!frame->children_results.empty()) + { + result_table_expr_ast->subquery = frame->children_results[0]; + } +} + +void MakeFunctionColumnAliasAction::visit(const ASTSubquery * subquery_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = subquery_ast->clone(); + auto * result_subquery_ast = frame->result_ast->as(); + result_subquery_ast->children = frame->children_results; +} + +void MakeFunctionColumnAliasAction::visit(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = select_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + LOG_TRACE(&Poco::Logger::get("MakeFunctionColumnAliasAction"), "ASTSelectQuery: {}", queryToString(*select_ast)); + LOG_TRACE(&Poco::Logger::get("MakeFunctionColumnAliasAction"), "children results size:{}", frame->children_results.size()); + + result_select_ast->setExpression(ASTSelectQuery::Expression::SELECT, frame->children_results[0]->clone()); + if (frame->children_results.size() == 2) + { + ASTBuildUtil::updateSelectQueryTables(result_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(frame->children_results[1]->as())->as()); + } + else if (frame->children_results.size() == 3) + { + ASTBuildUtil::updateSelectQueryTables( + result_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(frame->children_results[1]->as()) + ->as(), + ASTBuildUtil::createTablesInSelectQueryElement( + frame->children_results[2]->as(), select_ast->join()->table_join) + ->as()); + } +} + +void MakeFunctionColumnAliasAction::visit(const ASTSelectWithUnionQuery * select_with_union_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = select_with_union_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + result_select_ast->list_of_selects->children = frame->children_results; +} + +} diff --git a/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.h b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.h new file mode 100644 index 000000000000..56a4e78fd1f0 --- /dev/null +++ b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.h @@ -0,0 +1,90 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ + +class IdentifierRenameAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + struct VisitFrame + { + explicit VisitFrame(const ASTPtr & ast) : original_ast(ast) {} + const ASTPtr & original_ast; + ASTPtr result_ast; + ASTs children_results; + }; + + explicit IdentifierRenameAction(ContextPtr context_, const std::map & renamed_idents_) + : context(context_) + , renamed_idents(renamed_idents_) + {} + + ~IdentifierRenameAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + ASTPtr getResult() + { + assert(frames.size() == 1); + return frames.getTopFrame()->result_ast; + } + +private: + ContextPtr context; + std::map renamed_idents; + SimpleVisitFrameStack<> frames; + + void visit(const ASTIdentifier * ident_ast); +}; + +class MakeFunctionColumnAliasAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + explicit MakeFunctionColumnAliasAction(std::map * functions_alias_id_ = nullptr) : functions_alias_id(functions_alias_id_) + { + if (!functions_alias_id) + functions_alias_id = &local_functions_alias_id; + } + ~MakeFunctionColumnAliasAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + ASTPtr getResult() + { + assert(frames.size() == 1); + return frames.getTopFrame()->result_ast; + } +private: + std::map * functions_alias_id; + std::map local_functions_alias_id; + SimpleVisitFrameStack<> frames; + + void visit(const ASTFunction * function_ast); + void visit(const ASTExpressionList * expr_list_ast); + void visit(const ASTTableExpression * table_expr_ast); + void visit(const ASTSubquery * subquery_ast); + void visit(const ASTSelectQuery * select_ast); + void visit(const ASTSelectWithUnionQuery * select_with_union_ast); +}; +} diff --git a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp new file mode 100644 index 000000000000..7910ee1b94bd --- /dev/null +++ b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp @@ -0,0 +1,69 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +ASTs IdentifiterQualiferRemoveAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + ASTs children; + if (const auto * function_ast = ast->as()) + { + children = function_ast->arguments->children; + } + return children; +} + +void IdentifiterQualiferRemoveAction::visit(const ASTPtr & ast) +{ + if (const auto * function_ast = ast->as()) + { + auto frame = frames.getTopFrame(); + auto result_function_ast = std::make_shared(); + result_function_ast->name = function_ast->name; + result_function_ast->arguments = std::make_shared(); + result_function_ast->arguments->children = frame->children_results; + frame->result_ast = result_function_ast; + result_function_ast->alias = function_ast->tryGetAlias(); + } + else if (const auto * literal_ast = ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + } + else if (const auto * ident_ast = ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = std::make_shared(ident_ast->shortName()); + auto * result_ast = frame->result_ast->as(); + result_ast->alias = ident_ast->tryGetAlias(); + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid ast({}): {}", ast->getID(), queryToString(ast)); + } +} + +void IdentifiterQualiferRemoveAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void IdentifiterQualiferRemoveAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto result_ast = frames.getTopFrame()->result_ast; + frames.popFrame(); + frames.getTopFrame()->children_results.emplace_back(result_ast); +} +} diff --git a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.h b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.h new file mode 100644 index 000000000000..da17c7d1ad21 --- /dev/null +++ b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.h @@ -0,0 +1,24 @@ +#pragma once +#include +#include +namespace DB +{ +class IdentifiterQualiferRemoveAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + ~IdentifiterQualiferRemoveAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + Result getResult() + { + return frames.getTopFrame()->result_ast; + } +private: + SimpleVisitFrameStack<> frames; +}; +} diff --git a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp new file mode 100644 index 000000000000..3f298fbccc93 --- /dev/null +++ b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp @@ -0,0 +1,283 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +void NestedJoinQueryRewriteAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void NestedJoinQueryRewriteAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto child_result = frames.getTopFrame()->result_ast; + frames.popFrame(); + assert(frames.size() != 0); + frames.getTopFrame()->children_results.emplace_back(child_result); +} + +ASTs NestedJoinQueryRewriteAction::collectChildren(const ASTPtr & ast) +{ + + ASTs children; + if (!ast) + return children; + + if (const auto * table_expr_ast = ast->as()) + { + if (table_expr_ast->subquery) + children.emplace_back(table_expr_ast->subquery); + } + else if (const auto * subquery_ast = ast->as()) + { + children.emplace_back(subquery_ast->children[0]); + } + else if (const auto * select_ast = ast->as()) + { + if (const auto * left_table_expr = getTableExpression(*select_ast, 0)) + children.emplace_back(left_table_expr->clone()); + if (const auto * right_table_expr = getTableExpression(*select_ast, 1)) + children.emplace_back(right_table_expr->clone()); + } + else if (const auto * select_with_union_ast = ast->as()) + { + children = select_with_union_ast->list_of_selects->children; + } + return children; +} + +void NestedJoinQueryRewriteAction::visit(const ASTPtr & ast) +{ + if (const auto * select_with_union_ast = ast->as()) + { + visit(select_with_union_ast); + } + else if (const auto * select_ast = ast->as()) + { + visit(select_ast); + } + else if (const auto * table_expr_ast = ast->as()) + { + visit(table_expr_ast); + } + else if (const auto * subquery_ast = ast->as()) + { + visit(subquery_ast); + } + else + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast; + } +} + +void NestedJoinQueryRewriteAction::visit(const ASTSelectWithUnionQuery * select_with_union_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = select_with_union_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + result_select_ast->list_of_selects->children = frame->children_results; +} + +void NestedJoinQueryRewriteAction::visit(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + if (frame->children_results.size() == 1) + { + frame->result_ast = select_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + auto * table_expr_ast = frame->children_results[0]->as(); + if (table_expr_ast->subquery) + { + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + table_alias = tables_with_columns[0].table.table; + + table_expr_ast->subquery->as()->setAlias(table_alias); + } + auto table_element = ASTBuildUtil::createTablesInSelectQueryElement(table_expr_ast); + ASTBuildUtil::updateSelectQueryTables( + result_select_ast, table_element->as()); + renameSelectQueryIdentifiers(result_select_ast); + } + else if (frame->children_results.size() == 2) + { + //auto nested_select_ast_ref = select_ast->clone(); + auto nested_select_ast_ref = std::make_shared(); + auto * nested_select_ast = nested_select_ast_ref->as(); + ASTBuildUtil::updateSelectQueryTables( + nested_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(frame->children_results[0]->as()) + ->as(), + ASTBuildUtil::createTablesInSelectQueryElement( + frame->children_results[1]->as(), select_ast->join()->table_join) + ->as()); + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + CollectRequiredColumnsAction action(tables_with_columns); + ASTDepthFirstVisitor visitor(action, select_ast->clone()); + auto tables_required_cols = visitor.visit().required_columns; + auto & right_required_cols = tables_required_cols[1]; + ColumnWithDetailNameAndType::makeAliasByFullName(right_required_cols); + + auto nested_select_expr_list = std::make_shared(); + for (const auto & col : tables_required_cols[0]) + { + auto ident = std::make_shared(col.splitedFullName()); + if (!col.alias_name.empty()) + ident->alias = col.alias_name; + else if (!col.short_name.empty() && col.short_name != col.full_name) + ident->alias = col.short_name; + nested_select_expr_list->children.emplace_back(ident); + } + for (const auto & col : tables_required_cols[1]) + { + auto ident = std::make_shared(col.splitedFullName()); + ident->alias = col.alias_name; + if (ident->alias.empty()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Alias name is expected for {}", col.full_name); + nested_select_expr_list->children.emplace_back(ident); + } + nested_select_ast->setExpression(ASTSelectQuery::Expression::SELECT, nested_select_expr_list); + if (auto where_expr = select_ast->where()) + nested_select_ast->setExpression(ASTSelectQuery::Expression::WHERE, where_expr->clone()); + // replace all columns with the renamed map from inner sub-queries + renameSelectQueryIdentifiers(nested_select_ast); + + // add new columns to be renamed for the outside queries. + for (const auto & col : right_required_cols) + { + if (col.full_name != col.short_name) + { + if (!columns_alias.count(col.full_name)) + { + columns_alias[col.full_name] = col.alias_name; + } + } + } + + frame->result_ast = select_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + auto select_with_union_ast = std::make_shared(); + select_with_union_ast->list_of_selects = std::make_shared(); + select_with_union_ast->list_of_selects->children.emplace_back(nested_select_ast_ref); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + { + table_alias = tables_with_columns[0].table.table; + } + ASTBuildUtil::updateSelectQueryTables( + result_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(select_with_union_ast.get(), table_alias)->as()); + + renameSelectQueryIdentifiers(result_select_ast); + } + + clearRenameAlias(select_ast); +} + +void NestedJoinQueryRewriteAction::visit(const ASTTableExpression * table_expr_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = table_expr_ast->clone(); + auto * result_table_expr_ast = frame->result_ast->as(); + if (!frame->children_results.empty()) + result_table_expr_ast->subquery = frame->children_results[0]; +} + +void NestedJoinQueryRewriteAction::visit(const ASTSubquery * subquery_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = subquery_ast->clone(); + auto * result_subquery_ast = frame->result_ast->as(); + result_subquery_ast->children = frame->children_results; +} + +void NestedJoinQueryRewriteAction::updateIdentNames(ASTSelectQuery * ast, ASTSelectQuery::Expression index) +{ + auto expression = ast->getExpression(index, false); + if (!expression) + return; + IdentifierRenameAction action(context, columns_alias); + ASTDepthFirstVisitor visitor(action, expression); + auto result = visitor.visit(); + ast->setExpression(index, std::move(result)); +} + +void NestedJoinQueryRewriteAction::renameSelectQueryIdentifiers(ASTSelectQuery * select_ast) +{ + updateIdentNames(select_ast, ASTSelectQuery::Expression::SELECT); + updateIdentNames(select_ast, ASTSelectQuery::Expression::WHERE); + updateIdentNames(select_ast, ASTSelectQuery::Expression::GROUP_BY); + updateIdentNames(select_ast, ASTSelectQuery::Expression::ORDER_BY); + updateIdentNames(select_ast, ASTSelectQuery::Expression::PREWHERE); + + if (select_ast->join()) + { + auto * join = select_ast->join()->table_join->as(); + if (join->using_expression_list) + { + IdentifierRenameAction action(context, columns_alias); + ASTDepthFirstVisitor visitor(action, join->using_expression_list); + join->using_expression_list = visitor.visit(); + } + if (join->on_expression) + { + IdentifierRenameAction action(context, columns_alias); + ASTDepthFirstVisitor visitor(action, join->on_expression); + join->on_expression = visitor.visit(); + } + } +} +void NestedJoinQueryRewriteAction::clearRenameAlias(const ASTSelectQuery * select_ast) +{ + auto re_alias = collectAliasColumns(select_ast->select()->as()); + for (const auto & alias : re_alias) + { + columns_alias.erase(alias.first); + } +} + +std::map NestedJoinQueryRewriteAction::collectAliasColumns(const ASTExpressionList * select_expression) +{ + std::map res; + for (const auto & child : select_expression->children) + { + if (auto * ident = child->as()) + { + auto alias = ident->tryGetAlias(); + if (!alias.empty()) + { + res[ident->name()] = alias; + } + } + } + return res; +} +} diff --git a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h new file mode 100644 index 000000000000..f8a74dc66017 --- /dev/null +++ b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h @@ -0,0 +1,66 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +/** + * An motivate example for this rewriter + * select a, n from ( select l.a, sum(r.b) as n from t1 as l join t2 as r on l.a = r.a group by l.a ) where n > 100 + * + * Rewrite the above sql to eliminate 'r.b', rename it as r_b. + * select a, n from ( + * select l.a, sum(r_b) as n from ( + * select l.a as a, r.b as r_b from t1 as l join t2 as r on l.a = r.a + * )as l group by l.a + * ) where n > 100 + * + * Make the join into a nested sub-query. + * If a column comes from the right table without an alias, make one alias for it. + * If a column is the result of a function without an alias, make one alias for it. + * This is for the convinience of making the join/groupby action could be run in distruted mode. + */ + +class NestedJoinQueryRewriteAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + explicit NestedJoinQueryRewriteAction(ContextPtr context_) : context(context_) {} + ~NestedJoinQueryRewriteAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + ASTPtr getResult() + { + assert(frames.size() == 1); + return frames.getTopFrame()->result_ast; + } + +private: + SimpleVisitFrameStack<> frames; + ContextPtr context; + std::map columns_alias; + + void visit(const ASTSelectWithUnionQuery * select_with_union_ast); + void visit(const ASTSelectQuery * select_ast); + void visit(const ASTTableExpression * table_expr_ast); + void visit(const ASTSubquery * subquery_ast); + + void updateIdentNames(ASTSelectQuery * ast, ASTSelectQuery::Expression index); + void renameSelectQueryIdentifiers(ASTSelectQuery * select_ast); + void clearRenameAlias(const ASTSelectQuery * select_ast); + + static std::map collectAliasColumns(const ASTExpressionList * select_expression); +}; +} diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp new file mode 100644 index 000000000000..e0ec51530b7a --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp @@ -0,0 +1,369 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +StageQueryDistributedAggregationRewriteAction::StageQueryDistributedAggregationRewriteAction( + ContextPtr context_, ShuffleTableIdGeneratorPtr id_gen_) + : context(context_), id_generator(id_gen_ ? id_gen_ : std::make_shared()) +{ +} + +void StageQueryDistributedAggregationRewriteAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void StageQueryDistributedAggregationRewriteAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto frame = frames.getTopFrame(); + auto result_ast = frame->result_ast; + ASTs upstream_queries; + frame->mergeChildrenUpstreamQueries(); + if (frame->upstream_queries.size() == 1) + upstream_queries = frame->upstream_queries[0]; + frames.popFrame(); + frame = frames.getTopFrame(); + frame->children_results.emplace_back(result_ast); + frame->addChildUpstreamQueries(upstream_queries); +} + +ASTs StageQueryDistributedAggregationRewriteAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + ASTs children; + if (const auto * union_select_ast = ast->as()) + { + children = union_select_ast->list_of_selects->children; + } + else if (const auto * select_ast = ast->as()) + { + if (!select_ast->join()) + { + children.emplace_back(getTableExpression(*select_ast, 0)->clone()); + } + else + { + if (ASTAnalyzeUtil::hasAggregationColumn(ast) || ASTAnalyzeUtil::hasGroupBy(ast)) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "There is join in the SELECT with aggregation columns. query: {}", queryToString(ast)); + } + } + } + else if (const auto * table_expr_ast = ast->as()) + { + if (table_expr_ast->subquery) + children.emplace_back(table_expr_ast->subquery); + } + else if (const auto * subquery_ast = ast->as()) + { + children = subquery_ast->children; + } + return children; +} + +void StageQueryDistributedAggregationRewriteAction::visit(const ASTPtr & ast) +{ + if (const auto * select_ast = ast->as()) + { + visit(select_ast); + } + else if (const auto * select_with_union_ast = ast->as()) + { + visit(select_with_union_ast); + } + else if (const auto * table_expr_ast = ast->as()) + { + visit(table_expr_ast); + } + else if (const auto * subquery_ast = ast->as()) + { + visit(subquery_ast); + } + else + { + LOG_TRACE(logger, "Emptry action for ast({}): {}", ast->getID(), queryToString(ast)); + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + } +} + +void StageQueryDistributedAggregationRewriteAction::visit(const ASTSelectWithUnionQuery * union_select_ast) +{ + auto frame = frames.getTopFrame(); + auto result_union_select_ast_ref = union_select_ast->clone(); + auto * result_union_select_ast = result_union_select_ast_ref->as(); + result_union_select_ast->list_of_selects->children = frame->children_results; + + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(result_union_select_ast_ref, frame->upstream_queries[0]); + } + else + frame->result_ast = result_union_select_ast_ref; +} + +void StageQueryDistributedAggregationRewriteAction::visit(const ASTTableExpression * table_expr_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = table_expr_ast->clone(); + auto * result_ast = frame->result_ast->as(); + if (!frame->children_results.empty()) + { + result_ast->subquery = frame->children_results[0]; + } +} + +void StageQueryDistributedAggregationRewriteAction::visit(const ASTSubquery * subquery_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = subquery_ast->clone(); + auto * result_subquery_ast = frame->result_ast->as(); + result_subquery_ast->children = frame->children_results; +} + +void StageQueryDistributedAggregationRewriteAction::visit(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + if (frame->children_results.empty()) // join query + { + StageQueryDistributedJoinRewriteAction action(context, id_generator); + ASTDepthFirstVisitor visitor(action, select_ast->clone()); + auto rewrite_ast = visitor.visit(); + + auto * stage_query = rewrite_ast->as(); + if (!stage_query) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTStageQuery is expected. return query is : {}", queryToString(rewrite_ast)); + auto * return_select_ast = stage_query->current_query->as(); + if (!return_select_ast) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTSelectQuery is expected. return query is : {}", queryToString(stage_query->current_query)); + + + ASTs upstream_queries; + upstream_queries.insert(upstream_queries.end(), stage_query->upstream_queries.begin(), stage_query->upstream_queries.end()); + frame->upstream_queries = std::vector{upstream_queries}; + frame->result_ast = stage_query->current_query; + } + else + { + if (ASTAnalyzeUtil::hasAggregationColumn(select_ast) || ASTAnalyzeUtil::hasGroupBy(select_ast)) + { + if (!ASTAnalyzeUtil::hasGroupBy(select_ast)) + { + visitSelectQueryWithAggregation(select_ast); + } + else + { + visitSelectQueryWithGroupby(select_ast); + } + } + else + { + frame->result_ast = select_ast->clone(); + auto * result_ast = frame->result_ast->as(); + auto * table_expr_ast = frame->children_results[0]->as(); + if (table_expr_ast->subquery) + { + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + table_alias = tables_with_columns[0].table.table; + + table_expr_ast->subquery->as()->setAlias(table_alias); + } + auto table_element = ASTBuildUtil::createTablesInSelectQueryElement(table_expr_ast); + ASTBuildUtil::updateSelectQueryTables(result_ast, table_element->as()); + } + } + + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); + } +} + +void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithAggregation(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + auto * rewrite_table_expr = frame->children_results[0]->as(); + if (!rewrite_table_expr) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTTableExpression is expected. return query is : {}", queryToString(frame->children_results[0])); + if (rewrite_table_expr->subquery) + { + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + table_alias = tables_with_columns[0].table.table; + + rewrite_table_expr->subquery->as()->setAlias(table_alias); + } + + CollectQueryStoragesAction collect_storage_action(context); + ASTDepthFirstVisitor collect_storage_visitor(collect_storage_action, frame->children_results[0]); + auto storages = collect_storage_visitor.visit(); + + bool all_is_shuffle_storage = true; + if (storages.size() == 2) // It should be a join query in the subquery + { + for (const auto & storage : storages) + { + if (storage->getName() != StorageShuffleJoin::NAME) + { + all_is_shuffle_storage = false; + break; + } + } + + if (all_is_shuffle_storage) + { + auto tables = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + if (tables.size() != 1) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Tables size should be 1"); + + CollectRequiredColumnsAction collect_columns_action(tables); + ASTDepthFirstVisitor collect_columns_visitor(collect_columns_action, select_ast->clone()); + auto required_columns = collect_columns_visitor.visit().required_columns; + + auto insert_query = createShuffleInsert( + TableFunctionLocalShuffle::name, + rewrite_table_expr, + ColumnWithDetailNameAndType::toNamesAndTypesList(required_columns[0]), + nullptr); + + ASTs upstream_queries; + frame->mergeChildrenUpstreamQueries(); + if (!frame->upstream_queries[0].empty()) + upstream_queries.emplace_back(ASTStageQuery::make(insert_query, frame->upstream_queries[0])); + else + upstream_queries.emplace_back(insert_query); + frame->upstream_queries = std::vector{upstream_queries}; + + auto table_function = insert_query->as()->table_function->clone(); + table_function->as()->alias = ASTBuildUtil::getTableExpressionAlias(rewrite_table_expr); + + frame->result_ast = select_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + ASTBuildUtil::updateSelectQueryTables( + result_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(table_function->as())->as()); + } + } + else + { + all_is_shuffle_storage = false; + } + + if (!all_is_shuffle_storage) + { + frame->result_ast = select_ast->clone(); + auto * result_ast = frame->result_ast->as(); + ASTBuildUtil::updateSelectQueryTables( + result_ast, ASTBuildUtil::createTablesInSelectQueryElement(rewrite_table_expr)->as()); + + } +} + +void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithGroupby(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + auto * rewrite_table_expr = frame->children_results[0]->as(); + if (!rewrite_table_expr) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTTableExpression is expected. return query is : {}", queryToString(frame->children_results[0])); + if (rewrite_table_expr->subquery) + { + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + table_alias = tables_with_columns[0].table.table; + + rewrite_table_expr->subquery->as()->setAlias(table_alias); + } + + auto tables = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + if (tables.size() != 1) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Tables size should be 1"); + + CollectRequiredColumnsAction collect_columns_action(tables); + ASTDepthFirstVisitor collect_columns_visitor(collect_columns_action, select_ast->clone()); + auto required_columns = collect_columns_visitor.visit().required_columns; + + auto insert_query = createShuffleInsert( + TableFunctionLocalShuffle::name, + rewrite_table_expr, + ColumnWithDetailNameAndType::toNamesAndTypesList(required_columns[0]), + select_ast->groupBy()); + + ASTs upstream_queries; + frame->mergeChildrenUpstreamQueries(); + if (!frame->upstream_queries[0].empty()) + upstream_queries.emplace_back(ASTStageQuery::make(insert_query, frame->upstream_queries[0])); + else + upstream_queries.emplace_back(insert_query); + frame->upstream_queries = std::vector{upstream_queries}; + + auto table_function = insert_query->as()->table_function->clone(); + table_function->as()->alias = ASTBuildUtil::getTableExpressionAlias(rewrite_table_expr); + + frame->result_ast = select_ast->clone(); + auto * result_select_ast = frame->result_ast->as(); + ASTBuildUtil::updateSelectQueryTables( + result_select_ast, + ASTBuildUtil::createTablesInSelectQueryElement(table_function->as())->as()); +} + +ASTPtr StageQueryDistributedAggregationRewriteAction::createShuffleInsert( + const String & table_function_name, ASTTableExpression * table_expr, const NamesAndTypesList & table_desc, ASTPtr groupby_clause) +{ + ASTPtr hash_expr; + if (groupby_clause) + { + IdentifiterQualiferRemoveAction remove_qualifier_action; + ASTDepthFirstVisitor remove_qualifier_visitor(remove_qualifier_action, groupby_clause); + hash_expr = remove_qualifier_visitor.visit(); + } + + auto session_id = context->getClientInfo().current_query_id; + auto cluster_name = context->getSettings().use_cluster_for_distributed_shuffle.value; + auto table_id = getNextId(); + + auto table_function + = ASTBuildUtil::createShuffleTableFunction(table_function_name, cluster_name, session_id, table_id, table_desc, hash_expr); + + auto select_query_ref = std::make_shared(); + auto * select_query = select_query_ref->as(); + ASTBuildUtil::updateSelectQueryTables(select_query, table_expr); + select_query->setExpression(ASTSelectQuery::Expression::SELECT, ASTBuildUtil::createSelectExpression(table_desc)); + + return ASTBuildUtil::createTableFunctionInsertSelectQuery(table_function, ASTBuildUtil::wrapSelectQuery(select_query)); +} +} diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h new file mode 100644 index 000000000000..97f645e46863 --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h @@ -0,0 +1,71 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +class StageQueryDistributedAggregationRewriteAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + + struct Frame : public SimpleVisitFrame + { + explicit Frame(const ASTPtr & ast) : SimpleVisitFrame(ast) {} + std::vector upstream_queries; + + void addChildUpstreamQueries(const ASTs & queries) + { + upstream_queries.emplace_back(queries); + } + + void mergeChildrenUpstreamQueries() + { + ASTs result_upstream_queris; + for (auto & queries : upstream_queries) + result_upstream_queris.insert(result_upstream_queris.end(), queries.begin(), queries.end()); + upstream_queries = std::vector{result_upstream_queris}; + } + }; + + explicit StageQueryDistributedAggregationRewriteAction(ContextPtr context_, ShuffleTableIdGeneratorPtr id_gen_ = nullptr); + ~StageQueryDistributedAggregationRewriteAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + + void visit(const ASTPtr & ast); + + Result getResult() { return frames.getTopFrame()->result_ast; } + +private: + ContextPtr context; + ShuffleTableIdGeneratorPtr id_generator; + SimpleVisitFrameStack frames; + Poco::Logger * logger = &Poco::Logger::get("StageQueryDistributedAggregationRewriteAction"); + + + void visit(const ASTSelectWithUnionQuery * union_select_ast); + void visit(const ASTSelectQuery * select_ast); + void visitSelectQueryWithGroupby(const ASTSelectQuery * select_ast); + void visitSelectQueryWithAggregation(const ASTSelectQuery * select_ast); + void visit(const ASTTableExpression * table_expr_ast); + void visit(const ASTSubquery * subquery_ast); + + String getNextId() + { + static const String prefix = "agg_"; + return prefix + std::to_string(id_generator->nextId()); + } + + ASTPtr createShuffleInsert( + const String & table_function_name, ASTTableExpression * table_expr, const NamesAndTypesList & table_desc, ASTPtr groupby_clause); +}; +} diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp new file mode 100644 index 000000000000..40034530db0c --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp @@ -0,0 +1,561 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +StageQueryDistributedJoinRewriteAction::StageQueryDistributedJoinRewriteAction(ContextPtr context_, ShuffleTableIdGeneratorPtr id_gen_) + : context(context_) + , id_generator(id_gen_ ? id_gen_ : std::make_shared()) +{} + +void StageQueryDistributedJoinRewriteAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void StageQueryDistributedJoinRewriteAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto frame = frames.getTopFrame(); + auto result_ast = frame->result_ast; + ASTs upstream_queries; + frame->mergeChildrenUpstreamQueries(); + if (frame->upstream_queries.size() == 1) + upstream_queries = frame->upstream_queries[0]; + frames.popFrame(); + frame = frames.getTopFrame(); + frame->children_results.emplace_back(result_ast); + frame->addChildUpstreamQueries(upstream_queries); + //frame->upstream_queries.insert(frame->upstream_queries.end(), upstream_queries.begin(), upstream_queries.end()); + +} + +ASTs StageQueryDistributedJoinRewriteAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + + ASTs children; + if (const auto * select_ast = ast->as()) + { + if (ASTAnalyzeUtil::hasAggregationColumn(ast) || ASTAnalyzeUtil::hasGroupBy(ast)) + { + if (select_ast->join()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "There is join in the SELECT with aggregation columns. query: {}", queryToString(ast)); + } + } + else + { + if (const auto * left_table_expr = getTableExpression(*select_ast, 0)) + children.emplace_back(left_table_expr->clone()); + if (const auto * right_table_expr = getTableExpression(*select_ast, 1)) + children.emplace_back(right_table_expr->clone()); + } + + } + else if (const auto * select_with_union_ast = ast->as()) + { + children = select_with_union_ast->list_of_selects->children; + } + else if (const auto * table_expr_ast = ast->as()) + { + if (table_expr_ast->subquery) + children.emplace_back(table_expr_ast->subquery); + } + else if (const auto * subquery_ast = ast->as()) + { + children.emplace_back(subquery_ast->children[0]); + } + else + { + LOG_TRACE(logger, "Return empty children for ast({}): {}", ast->getID(), queryToString(ast)); + } + return children; +} + +void StageQueryDistributedJoinRewriteAction::visit(const ASTPtr & ast) +{ + if (const auto * select_ast = ast->as()) + { + visit(select_ast); + } + else if (const auto * select_with_union_ast = ast->as()) + { + visit(select_with_union_ast); + } + else if (const auto * table_expr_ast = ast->as()) + { + visit(table_expr_ast); + } + else if (const auto * subquery_ast = ast->as()) + { + visit(subquery_ast); + } + else + { + LOG_TRACE(logger, "Emptry action for ast({}): {}", ast->getID(), queryToString(ast)); + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + } +} + +void StageQueryDistributedJoinRewriteAction::visit(const ASTTableExpression * table_expr_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = table_expr_ast->clone(); + auto * result_ast = frame->result_ast->as(); + if (!frame->children_results.empty()) + { + result_ast->subquery = frame->children_results[0]; + } +} + +void StageQueryDistributedJoinRewriteAction::visit(const ASTSubquery * subquery_ast) +{ + auto frame = frames.getTopFrame(); + frame->result_ast = subquery_ast->clone(); + auto * result_subquery_ast = frame->result_ast->as(); + result_subquery_ast->children = frame->children_results; +} + +void StageQueryDistributedJoinRewriteAction::visit(const ASTSelectWithUnionQuery * select_with_union_ast) +{ + auto frame = frames.getTopFrame(); + auto result_union_select_ast_ref = select_with_union_ast->clone(); + auto * result_union_select_ast = result_union_select_ast_ref->as(); + result_union_select_ast->list_of_selects->children = frame->children_results; + + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(result_union_select_ast_ref, frame->upstream_queries[0]); + } + else + { + frame->result_ast = result_union_select_ast_ref; + } +} + +void StageQueryDistributedJoinRewriteAction::visit(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + if (frame->children_results.empty()) + { + visitSelectQueryWithAggregation(select_ast); + } + else if (frame->children_results.size() == 1) + { + frame->result_ast = select_ast->clone(); + auto * result_ast = frame->result_ast->as(); + auto * table_expr_ast = frame->children_results[0]->as(); + if (table_expr_ast->subquery) + { + auto tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + String table_alias = tables_with_columns[0].table.alias; + if (table_alias.empty() && !tables_with_columns[0].table.table.empty()) + table_alias = tables_with_columns[0].table.table; + + table_expr_ast->subquery->as()->setAlias(table_alias); + } + auto table_element = ASTBuildUtil::createTablesInSelectQueryElement(table_expr_ast); + ASTBuildUtil::updateSelectQueryTables(result_ast, table_element->as()); + + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); + } + } + else if (frame->children_results.size() == 2) + { + visitSelectQueryOnJoin(select_ast); + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid children size."); + } + + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); + } +} + +void StageQueryDistributedJoinRewriteAction::visitSelectQueryWithAggregation(const ASTSelectQuery * select_ast) +{ + auto frame = frames.getTopFrame(); + StageQueryDistributedAggregationRewriteAction distributed_join_action(context, id_generator); + ASTDepthFirstVisitor distributed_join_visitor(distributed_join_action, select_ast->clone()); + auto rewrite_ast = distributed_join_visitor.visit(); + auto * stage_query = rewrite_ast->as(); + if (!stage_query) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTStageQuery is expected. return query is : {}", queryToString(rewrite_ast)); + if (!stage_query->current_query->as()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTSelectQuery is expected. return query is : {}", queryToString(stage_query->current_query)); + + ASTs upstream_queries; + upstream_queries.insert(upstream_queries.end(), stage_query->upstream_queries.begin(), stage_query->upstream_queries.end()); + frame->upstream_queries = std::vector{upstream_queries}; + frame->result_ast = stage_query->current_query; +} + +void StageQueryDistributedJoinRewriteAction::visitSelectQueryOnJoin(const ASTSelectQuery * select_ast) +{ + StageQueryDistributedJoinRewriteAnalyzer join_analyzer(select_ast, context); + auto analyze_result = join_analyzer.analyze(); + if (!analyze_result) // Not rewrite it into shuffle pattern. + { + auto frame = frames.getTopFrame(); + frame->result_ast = select_ast->clone(); + auto * result_ast = frame->result_ast->as(); + ASTBuildUtil::updateSelectQueryTables( + result_ast, + ASTBuildUtil::createTablesInSelectQueryElement(frame->children_results[0]->as(), nullptr) + ->as(), + ASTBuildUtil::createTablesInSelectQueryElement( + frame->children_results[1]->as(), result_ast->join()->table_join) + ->as()); + if (frames.size() == 1) + { + frame->mergeChildrenUpstreamQueries(); + frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); + } + } + else + { + auto frame = frames.getTopFrame(); + auto current_table_id = getNextTableId(); + auto left_shuffle_table_id = current_table_id + "_left"; + auto right_shuffle_table_id = current_table_id + "_right"; + auto * left_visited_table_expr = frame->children_results[0]->as(); + auto * right_visited_table_expr = frame->children_results[1]->as(); + ASTs upstream_queris; + + auto left_shuffle_query = createShuffleInsertForJoin( + left_shuffle_table_id, left_visited_table_expr, analyze_result->tables_columns[0], analyze_result->tables_hash_keys[0]); + if (frame->upstream_queries[0].empty()) + upstream_queris.emplace_back(left_shuffle_query); + else + upstream_queris.emplace_back(ASTStageQuery::make(left_shuffle_query, frame->upstream_queries[0])); + auto right_shuffle_query = createShuffleInsertForJoin( + right_shuffle_table_id, right_visited_table_expr, analyze_result->tables_columns[1], analyze_result->tables_hash_keys[1]); + if (frame->upstream_queries[1].empty()) + upstream_queris.emplace_back(right_shuffle_query); + else + upstream_queris.emplace_back(ASTStageQuery::make(right_shuffle_query, frame->upstream_queries[1])); + frame->upstream_queries = std::vector{upstream_queris}; + + const auto & cluster_name = context->getSettings().use_cluster_for_distributed_shuffle.value; + const auto & query_id = context->getClientInfo().current_query_id; + frame->result_ast = select_ast->clone(); + auto * result_ast = frame->result_ast->as(); + + auto left_table = ASTBuildUtil::createTablesInSelectQueryElement( + ASTBuildUtil::createShuffleTableFunction( + TableFunctionShuffleJoin::name, + cluster_name, + query_id, + left_shuffle_table_id, + analyze_result->tables_columns[0], + analyze_result->tables_hash_keys[0], + ASTBuildUtil::getTableExpressionAlias(getTableExpression(*select_ast, 0))) + ->as(), + nullptr); + auto right_table = ASTBuildUtil::createTablesInSelectQueryElement( + ASTBuildUtil::createShuffleTableFunction( + TableFunctionShuffleJoin::name, + cluster_name, + query_id, + right_shuffle_table_id, + analyze_result->tables_columns[1], + analyze_result->tables_hash_keys[1], + ASTBuildUtil::getTableExpressionAlias(getTableExpression(*select_ast, 1))) + ->as(), + select_ast->join()->table_join); + ASTBuildUtil::updateSelectQueryTables( + result_ast, left_table->as(), right_table->as()); + } +} + +ASTPtr StageQueryDistributedJoinRewriteAction::createShuffleInsertForJoin( + const String & table_id, ASTTableExpression * table_expr, const NamesAndTypesList & table_desc, const ASTPtr & hash_expr) +{ + auto select_query_ref = std::make_shared(); + auto * select_query = select_query_ref->as(); + ASTBuildUtil::updateSelectQueryTables(select_query, table_expr); + select_query->setExpression(ASTSelectQuery::Expression::SELECT, ASTBuildUtil::createSelectExpression(table_desc)); + + auto hash_table = ASTBuildUtil::createShuffleTableFunction( + TableFunctionShuffleJoin::name, + context->getSettings().use_cluster_for_distributed_shuffle.value, + context->getClientInfo().current_query_id, + table_id, + table_desc, + hash_expr); + + return ASTBuildUtil::createTableFunctionInsertSelectQuery(hash_table, ASTBuildUtil::wrapSelectQuery(select_query)); +} + +StageQueryDistributedJoinRewriteAnalyzer::StageQueryDistributedJoinRewriteAnalyzer(const ASTSelectQuery * query_, ContextPtr context_) + : from_query(query_) + , context(context_) + , tables_columns_from_select({{}, {}}) + , tables_columns_from_on_join({{}, {}}) + , tables_hash_keys({std::make_shared(), std::make_shared()}) + , tables_columns({{}, {}}) +{ +} + +std::optional StageQueryDistributedJoinRewriteAnalyzer::analyze() +{ + if (!isApplicableJoinType()) + return {}; + + tables_with_columns = getDatabaseAndTablesWithColumns(getTableExpressions(*from_query), context, true, true); + + if (!collectHashKeys()) + return {}; + + collectTablesColumns(); + + return Result{.tables_columns = tables_columns, .tables_hash_keys = tables_hash_keys}; +} + + +bool StageQueryDistributedJoinRewriteAnalyzer::isApplicableJoinType() +{ + const auto * join_tables = from_query->join(); + auto * table_join = join_tables->table_join->as(); + if (table_join->kind == ASTTableJoin::Kind::Cross) + return false; + + // TODO if right table is dict or special storage, return false; + if (table_join->on_expression) + { + if (auto * or_func = table_join->on_expression->as(); or_func && or_func->name == "or") + { + LOG_INFO(logger, "Not support or join. {}", queryToString(*table_join)); + return false; + } + } + else + return false;// using clause + + // if it is a special storage, return false + const auto * join_ast = from_query->join(); + const auto & table_to_join = join_ast->table_expression->as(); + if (table_to_join.database_and_table_name) + { + auto joined_table_id = context->resolveStorageID(table_to_join.database_and_table_name); + StoragePtr storage = DatabaseCatalog::instance().tryGetTable(joined_table_id, context); + if (storage) + { + return false; + } + } + + return true; +} + +bool StageQueryDistributedJoinRewriteAnalyzer::collectHashKeys() +{ + CollectAliasColumnElementAction collect_alias_cols_action; + ASTDepthFirstVisitor collect_alias_cols_visitor(collect_alias_cols_action, from_query->clone()); + auto alias_cols = collect_alias_cols_visitor.visit(); + + auto * table_join = from_query->join()->table_join->as(); + if (table_join->on_expression) + { + auto * func = table_join->on_expression->as(); + if (func->name == "equals") + { + return collectHashKeysOnEqual(table_join->on_expression, tables_hash_keys, alias_cols); + } + else if (func->name == "and") + { + return collectHashKeysOnAnd(table_join->on_expression, tables_hash_keys, alias_cols); + } + else + LOG_TRACE(logger, "Unsupport function({}) on join clause.", func->name); + return false; + } + else + { + LOG_INFO(logger, "Join by using clause is not support now"); + return false; + } + __builtin_unreachable(); +} + +bool StageQueryDistributedJoinRewriteAnalyzer::collectHashKeysOnEqual( + ASTPtr ast, ASTs & keys, const std::map & alias_columns) +{ + auto * func = ast->as(); + auto & left_arg = func->arguments->children[0]; + auto & right_arg = func->arguments->children[1]; + ASTPtr left_key = nullptr, right_key = nullptr; + { + CollectRequiredColumnsAction action(tables_with_columns, alias_columns); + ASTDepthFirstVisitor visitor(action, left_arg); + auto columns = visitor.visit().required_columns; + if (!columns[0].empty() && columns[1].empty()) + { + left_key = left_arg; + } + else if (columns[0].empty() && !columns[1].empty()) + { + right_key = left_arg; + } + else + { + LOG_INFO(logger, "Cannot find pos for arg: {}", queryToString(left_arg)); + return false; + } + } + { + CollectRequiredColumnsAction action(tables_with_columns, alias_columns); + ASTDepthFirstVisitor visitor(action, right_arg); + auto columns = visitor.visit().required_columns; + if (!columns[0].empty() && columns[1].empty()) + { + left_key = right_arg; + } + else if (columns[0].empty() && !columns[1].empty()) + { + right_key = right_arg; + } + else + { + LOG_INFO(logger, "Cannot find pos for arg: {}", queryToString(right_arg)); + return false; + } + } + if (!left_key || !right_key) + { + LOG_INFO(logger, "Collect join keys failed for {}", queryToString(ast)); + return false; + } + + auto remove_alias = [](ASTPtr alias_ast) + { + if (auto * ident = alias_ast->as()) + { + ident->alias = ""; + } + else if (auto * func_ast = alias_ast->as()) + { + func_ast->alias = ""; + } + }; + + auto replace_alias_ast = [&alias_columns, remove_alias](ASTPtr & to_replace_ast) + { + if (auto * ident_ast = to_replace_ast->as()) + { + auto iter = alias_columns.find(ident_ast->name()); + if (iter != alias_columns.end()) + { + to_replace_ast = iter->second; + remove_alias(to_replace_ast); + } + } + }; + replace_alias_ast(left_key); + replace_alias_ast(right_key); + + IdentifiterQualiferRemoveAction left_key_action; + ASTDepthFirstVisitor left_key_visitor(left_key_action, left_key); + left_key = left_key_visitor.visit(); + IdentifiterQualiferRemoveAction right_key_action; + ASTDepthFirstVisitor right_key_visitor(right_key_action, right_key); + right_key = right_key_visitor.visit(); + keys[0]->children.push_back(left_key); + keys[1]->children.push_back(right_key); + return true; +} + +bool StageQueryDistributedJoinRewriteAnalyzer::collectHashKeysOnAnd( + ASTPtr ast, ASTs & keys_list, const std::map & alias_columns) +{ + auto * func = ast->as(); + for (auto & arg : func->arguments->children) + { + if (!collectHashKeysOnEqual(arg, keys_list, alias_columns)) + return false; + } + return true; +} + +void StageQueryDistributedJoinRewriteAnalyzer::collectTablesColumns() +{ + CollectAliasColumnElementAction collect_alias_col_action; + ASTDepthFirstVisitor collect_alias_col_visitor(collect_alias_col_action, from_query->select()); + auto alias_cols = collect_alias_col_visitor.visit(); + + CollectRequiredColumnsAction collect_select_action(tables_with_columns); + ASTDepthFirstVisitor collect_select_visitor(collect_select_action, from_query->select()); + tables_columns_from_select = collect_select_visitor.visit().required_columns; + + CollectRequiredColumnsAction collect_join_action(tables_with_columns, alias_cols); + ASTDepthFirstVisitor collect_join_visitor(collect_join_action, from_query->join()->table_join); + tables_columns_from_on_join = collect_join_visitor.visit().required_columns; + + for (size_t i = 0; i < tables_columns_from_select.size(); ++i) + { + auto & select_columns = tables_columns_from_select[i]; + auto & join_columns = tables_columns_from_on_join[i]; + auto & result_columns = tables_columns[i]; + + std::set added_columns; + for (const auto & col : select_columns) + { + const auto & name = col.short_name; + if (added_columns.count(name)) + continue; + result_columns.emplace_back(NameAndTypePair(name, col.type)); + added_columns.insert(name); + } + + for (const auto & col : join_columns) + { + const auto & name = col.short_name; + if (added_columns.count(name)) + continue; + result_columns.emplace_back(NameAndTypePair(name, col.type)); + added_columns.insert(name); + } + } +} + + +} diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.h b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.h new file mode 100644 index 000000000000..371d5acaedbb --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.h @@ -0,0 +1,107 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +class StageQueryDistributedJoinRewriteAnalyzer +{ +public: + struct Result + { + // for descrite table structure + std::vector tables_columns; + // each is ASTExpressionList, containes all required columns for hashing + ASTs tables_hash_keys; + }; + explicit StageQueryDistributedJoinRewriteAnalyzer(const ASTSelectQuery * query_, ContextPtr context_); + + // If this query cannot be rewrite under this strategy, return empty + std::optional analyze(); +private: + const ASTSelectQuery * from_query; + ContextPtr context; + TablesWithColumns tables_with_columns; + std::vector tables_columns_from_select; + std::vector tables_columns_from_on_join; + ASTs tables_hash_keys; + std::vector tables_columns; + + Poco::Logger * logger = &Poco::Logger::get("StageQueryDistributedJoinRewriteAnalyzer"); + + bool isApplicableJoinType(); + + bool collectHashKeys(); + bool collectHashKeysOnEqual(ASTPtr ast, ASTs & keys, const std::map & alias_columns); + bool collectHashKeysOnAnd(ASTPtr ast, ASTs & keys, const std::map & alias_columns); + void collectTablesColumns(); +}; +class StageQueryDistributedJoinRewriteAction : public EmptyASTDepthFirstVisitAction +{ +public: + struct Frame : public SimpleVisitFrame + { + explicit Frame(const ASTPtr & ast) : SimpleVisitFrame(ast) {} + std::vector upstream_queries; + + void addChildUpstreamQueries(const ASTs & queries) + { + upstream_queries.emplace_back(queries); + } + + void mergeChildrenUpstreamQueries() + { + ASTs result_upstream_queris; + for (auto & queries : upstream_queries) + result_upstream_queris.insert(result_upstream_queris.end(), queries.begin(), queries.end()); + upstream_queries = std::vector{result_upstream_queris}; + } + }; + + using Result = ASTPtr; + explicit StageQueryDistributedJoinRewriteAction(ContextPtr context_, ShuffleTableIdGeneratorPtr id_gen_ = nullptr); + ~StageQueryDistributedJoinRewriteAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + inline Result getResult() + { + return frames.getTopFrame()->result_ast; + } + +private: + ContextPtr context; + ShuffleTableIdGeneratorPtr id_generator; + SimpleVisitFrameStack frames; + Poco::Logger * logger = &Poco::Logger::get("StageQueryDistributedJoinRewriteAction"); + + + void visit(const ASTSelectWithUnionQuery * select_with_union_ast); + void visit(const ASTSelectQuery * select_ast); + void visitSelectQueryWithAggregation(const ASTSelectQuery * select_ast); + void visitSelectQueryOnJoin(const ASTSelectQuery * select_ast); + void visit(const ASTTableExpression * table_expr_ast); + void visit(const ASTSubquery * subquery_ast); + + ASTPtr createShuffleInsertForJoin(const String & table_id, + ASTTableExpression * table_expr, + const NamesAndTypesList & table_desc, + const ASTPtr & hash_expr); + + + String getNextTableId() + { + static const String prefix = "join_"; + return prefix + std::to_string(id_generator->nextId()); + } +}; +} diff --git a/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.cpp new file mode 100644 index 000000000000..f541e1d85b5c --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.cpp @@ -0,0 +1,105 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ + +void StageQueryShuffleFinishEventRewriteAction::beforeVisitChildren(const ASTPtr & ast) +{ + frames.pushFrame(ast); +} + +void StageQueryShuffleFinishEventRewriteAction::afterVisitChild(const ASTPtr & /*ast*/) +{ + auto child_result = frames.getTopFrame()->result_ast; + frames.popFrame(); + frames.getTopFrame()->children_results.emplace_back(child_result); +} + +ASTs StageQueryShuffleFinishEventRewriteAction::collectChildren(const ASTPtr & ast) +{ + if (!ast) + return {}; + ASTs children; + if (const auto * stage_ast = ast->as()) + { + children.emplace_back(stage_ast->current_query); + children.insert(children.end(), stage_ast->upstream_queries.begin(), stage_ast->upstream_queries.end()); + } + return children; +} + +void StageQueryShuffleFinishEventRewriteAction::visit(const ASTPtr & ast) +{ + if (const auto * stage_ast = ast->as()) + { + visit(stage_ast); + } + else if (const auto * insert_ast = ast->as()) + { + visit(insert_ast); + } + else + { + auto frame = frames.getTopFrame(); + frame->result_ast = frame->original_ast->clone(); + } +} + +void StageQueryShuffleFinishEventRewriteAction::visit(const ASTStageQuery * /*stage_ast*/) +{ + auto frame = frames.getTopFrame(); + auto current_query = frame->children_results[0]; + ASTs upstream_queries; + if (frame->children_results.size() > 1) + { + upstream_queries.insert(upstream_queries.end(), frame->children_results.begin() + 1, frame->children_results.end()); + } + + frame->result_ast = ASTStageQuery::make(current_query, upstream_queries); +} + +void StageQueryShuffleFinishEventRewriteAction::visit(const ASTInsertQuery * insert_ast) +{ + auto frame = frames.getTopFrame(); + if (!insert_ast->table_function) + { + frame->result_ast = frame->original_ast->clone(); + return; + } + + auto * table_function = insert_ast->table_function->as(); + auto & function_name = table_function->name; + if (function_name == TableFunctionShuffleJoin::name || function_name == TableFunctionShuffleAggregation::name + || function_name == TableFunctionLocalShuffle::name) + { + auto & args = table_function->arguments; + auto cluster_name = args->children[0]->as()->value.safeGet(); + auto session_id = args->children[1]->as()->value.safeGet(); + auto table_id = args->children[2]->as()->value.safeGet(); + + auto event_table_func = std::make_shared(); + event_table_func->name = TableFunctionClosedShuffle::name; + event_table_func->arguments = std::make_shared(); + event_table_func->arguments->children.push_back(args->children[0]); // cluster name + event_table_func->arguments->children.push_back(args->children[1]); // session id + event_table_func->arguments->children.push_back(args->children[2]); // table id + + auto event_insert_query = std::make_shared(); + event_insert_query->table_function = event_table_func; + frame->result_ast = ASTStageQuery::make(event_insert_query, {insert_ast->clone()}); + } + else + { + frame->result_ast = frame->original_ast->clone(); + } +} + +} diff --git a/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.h b/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.h new file mode 100644 index 000000000000..a955e62aba74 --- /dev/null +++ b/src/Interpreters/ASTRewriters/StageQueryShuffleFinishEventRewriteAction.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +/** + * @brief For constucting a insert query to signal a shuffle block table has finished insertion. Always follow a insert query + * that inserts into a shuffle block table. + * + */ + +class StageQueryShuffleFinishEventRewriteAction : public EmptyASTDepthFirstVisitAction +{ +public: + using Result = ASTPtr; + StageQueryShuffleFinishEventRewriteAction() = default; + ~StageQueryShuffleFinishEventRewriteAction() override = default; + + ASTs collectChildren(const ASTPtr & ast) override; + void beforeVisitChildren(const ASTPtr & ast) override; + void afterVisitChild(const ASTPtr & ast) override; + void visit(const ASTPtr & ast); + + Result getResult() + { + auto frame = frames.getTopFrame(); + return frame->result_ast; + } +private: + SimpleVisitFrameStack<> frames; + + void visit(const ASTStageQuery * stage_ast); + void visit(const ASTInsertQuery * insert_ast); +}; +} diff --git a/src/Interpreters/InterpreterFactory.cpp b/src/Interpreters/InterpreterFactory.cpp index 5dcee1eae05c..7f0041796b9f 100644 --- a/src/Interpreters/InterpreterFactory.cpp +++ b/src/Interpreters/InterpreterFactory.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -65,6 +66,7 @@ #include #include #include +#include #include #include @@ -296,6 +298,10 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, ContextMut { return std::make_unique(query, context); } + else if (query->as()) + { + return std::make_unique(query, context, options); + } else { throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY); diff --git a/src/Interpreters/InterpreterStageQuery.cpp b/src/Interpreters/InterpreterStageQuery.cpp new file mode 100644 index 000000000000..d723df6611e2 --- /dev/null +++ b/src/Interpreters/InterpreterStageQuery.cpp @@ -0,0 +1,422 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} +InterpreterStageQuery::InterpreterStageQuery(ASTPtr query_, ContextPtr context_, SelectQueryOptions options_) + : query(query_) + , context(context_) + , options(options_) +{ +} + +BlockIO InterpreterStageQuery::execute() +{ + Stopwatch watch; + auto stage_query = std::dynamic_pointer_cast(query); + if (!stage_query) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ASTStageQuery is expected, but we get {}", query->getID()); + QueryBlockIOs input_block_ios; + std::mutex input_block_ios_mutex; + auto build_block_io_task = [&](ASTPtr input_query) + { + auto res = buildBlockIO(input_query); + std::lock_guard lock{input_block_ios_mutex}; + input_block_ios.emplace_back(res); + }; + const size_t default_thread_nums = 1; + size_t thread_nums = std::min(stage_query->upstream_queries.size() + 1, std::max(default_thread_nums, stage_query->upstream_queries.size() + 1)); + // Since it may need to query informations from outside storages, and the cost is high, make it run currently + ThreadPool thread_pool{thread_nums}; + for (auto & input_query : stage_query->upstream_queries) + { + thread_pool.scheduleOrThrowOnError([&input_query, build_block_io_task]() { build_block_io_task(input_query); }); + //input_block_ios.emplace_back(buildBlockIO(input_query)); + } + QueryBlockIO output_block_io; + thread_pool.scheduleOrThrowOnError( + [&]() + { + output_block_io = buildBlockIO(stage_query->current_query); + }); + thread_pool.wait(); + //auto output_block_io = buildBlockIO(stage_query->current_query); + + auto res = execute(output_block_io, input_block_ios); + return res; +} + +BlockIO InterpreterStageQuery::execute(const QueryBlockIO & output_io, const QueryBlockIOs & input_block_ios) +{ + QueryPlan query_plan; + if (couldRunParallelly(query->as())) + { + query_plan.addStep(std::make_unique(context, output_io, input_block_ios)); + } + else + { + query_plan.addStep(std::make_unique(context, output_io, input_block_ios)); + } + auto pipeline_builder = query_plan.buildQueryPipeline( + QueryPlanOptimizationSettings::fromContext(context), + BuildQueryPipelineSettings::fromContext(context)); + pipeline_builder->addInterpreterContext(context); + BlockIO res; + res.pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + return res; +} + +QueryBlockIO InterpreterStageQuery::buildBlockIO(ASTPtr query_) +{ + QueryBlockIO res; + if (query_->as()) + { + res = buildInsertBlockIO(query_); + } + else if (query_->as()) + { + res = buildSelectBlockIO(query_); + } + else if (query_->as()) + { + res = buildTreeBlockIO(query_); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknow query type : {}", query_->getID()); + } + return res; +} + +QueryBlockIO InterpreterStageQuery::buildInsertBlockIO(ASTPtr insert_query) +{ + Stopwatch watch; + auto distributed_queries = tryToMakeDistributedInsertQueries(insert_query); + if (!distributed_queries) + { + auto interpreter = InterpreterInsertQuery(insert_query, context); + auto res = std::make_shared(interpreter.execute()); + return {.block_io = res, .query = insert_query}; + } + const Scalars & scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{}; + Pipes pipes; + for (auto & shard_query : *distributed_queries) + { + auto & node = shard_query.first.first; + auto & task_extension = shard_query.first.second; + auto & remote_query = shard_query.second; + auto connection = std::make_shared( + node.host_name, + node.port, + context->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "InterpreterStageQuery remote insert", + node.compression, + node.secure); + LOG_TRACE(logger, "run on node:{}, query:{}", node.host_name, remote_query); + auto remote_query_executor = std::make_shared( + connection, + remote_query, + Block{}, + context, + nullptr, + scalars, + Tables(), + QueryProcessingStage::Complete, + task_extension); + pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); + } + auto pipe = Pipe::unitePipes(std::move(pipes)); + QueryPipelineBuilder pipeline_builder; + pipeline_builder.init(std::move(pipe)); + auto res = std::make_shared(); + res->pipeline = QueryPipelineBuilder::getPipeline(std::move(pipeline_builder)); + res->pipeline.setNumThreads(3); + return {.block_io = res, .query = insert_query}; +} + +std::vector InterpreterStageQuery::getSelectStorages(ASTPtr ast) +{ + CollectQueryStoragesAction action(context); + ASTDepthFirstVisitor visitor(action, ast); + return visitor.visit(); +} + +ASTPtr InterpreterStageQuery::unwrapSingleSelectQuery(const ASTPtr & ast) +{ + if (auto * select_with_union = ast->as()) + { + if (select_with_union->list_of_selects->children.size() > 1) + return nullptr; + return unwrapSingleSelectQuery(select_with_union->list_of_selects->children[0]); + } + else if (auto * select = ast->as()) + { + if (select->join()) + return nullptr; + auto table_expression = extractTableExpression(*select, 0); + if (auto * select_ast = table_expression->as()) + { + return unwrapSingleSelectQuery(table_expression); + } + return ast; + } + return nullptr; +} + +std::optional>> InterpreterStageQuery::tryToMakeDistributedInsertQueries(ASTPtr from_query) +{ + String cluster_name = context->getSettings().use_cluster_for_distributed_shuffle.value; + auto * insert_query = from_query->as(); + + if (insert_query->table_function && insert_query->table_function->as()->name == TableFunctionClosedShuffle::name) + { + String query_str = queryToString(from_query); + query_str += " (1)"; + + std::list> res; + auto cluster = context->getCluster(cluster_name)->getClusterWithReplicasAsShards(context->getSettings()); + for (const auto & replicas : cluster->getShardsAddresses()) + { + for (const auto & node : replicas) + { + DistributedTask task(node, RemoteQueryExecutor::Extension{}); + res.emplace_back(std::make_pair(task, query_str)); + } + } + return res; + } + + auto storages = getSelectStorages(insert_query->select); + bool has_groupby = ASTAnalyzeUtil::hasGroupByRecursively(from_query); + bool has_agg = ASTAnalyzeUtil::hasAggregationColumnRecursively(from_query); + DistributedTasks tasks; + if (storages.size() == 2) + { + for (const auto & storage : storages) + { + if (storage->getName() != StorageShuffleJoin::NAME) + { + return {}; + } + } + if (has_groupby || has_agg) + return {}; + auto cluster = context->getCluster(cluster_name)->getClusterWithReplicasAsShards(context->getSettings()); + for (const auto & replicas : cluster->getShardsAddresses()) + { + for (const auto & node : replicas) + { + DistributedTask task(node, RemoteQueryExecutor::Extension{}); + tasks.emplace_back(task); + } + } + } + else if (storages.size() == 1) + { + if (storages[0]->getName() == StorageLocalShuffle::NAME) + return {}; + auto distributed_tasks_builder = StorageDistributedTaskBuilderFactory::getInstance().getBuilder(storages[0]->getName()); + if (!distributed_tasks_builder) + { + LOG_INFO(logger, "Not found builder for {}", storages[0]->getName()); + return {}; + } + + auto nested_select_query = unwrapSingleSelectQuery(insert_query->select); + if (!nested_select_query) + return {}; + tasks = distributed_tasks_builder->getDistributedTasks( + cluster_name, context, nested_select_query, storages[0]); + if (tasks.empty()) + { + return {}; + } + } + else + { + return {}; + } + // just for test + std::list> res; + auto query_str = queryToString(from_query); + for (const auto & task : tasks) + { + res.emplace_back(std::make_pair(task, query_str)); + } + return res; +} + +QueryBlockIO InterpreterStageQuery::buildSelectBlockIO(ASTPtr select_query) +{ + auto distributed_queries = tryToMakeDistributedSelectQueries(select_query); + if (!distributed_queries) + { + InterpreterSelectWithUnionQuery interpreter(select_query, context, options); + auto res = std::make_shared(interpreter.execute()); + return {.block_io = res, .query = select_query}; + } + InterpreterSelectWithUnionQuery select_interpreter(select_query, context, options); + Block header = select_interpreter.getSampleBlock(); + const Scalars & scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{}; + Pipes pipes; + for (auto & shard_query : *distributed_queries) + { + auto & node = shard_query.first.first; + auto task_extension = shard_query.first.second; + auto & remote_query = shard_query.second; + auto connection = std::make_shared( + node.host_name, + node.port, + context->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "InterpreterStageQuery remote select", + node.compression, + node.secure); + + auto remote_query_executor = std::make_shared( + connection, + remote_query, + header, + context, + nullptr, + scalars, + Tables(), + QueryProcessingStage::Complete, + RemoteQueryExecutor::Extension{}); + pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); + } + auto pipe = Pipe::unitePipes(std::move(pipes)); + QueryPipelineBuilder pipeline_builder; + pipeline_builder.init(std::move(pipe)); + auto res = std::make_shared(); + res->pipeline = QueryPipelineBuilder::getPipeline(std::move(pipeline_builder)); + return {.block_io = res, .query = select_query}; +} + +std::list> InterpreterStageQuery::buildSelectTasks(ASTPtr from_query) +{ + std::list> res; + String query_str = queryToString(from_query); + String cluster_name = context->getSettings().use_cluster_for_distributed_shuffle.value; + auto cluster = context->getCluster(cluster_name)->getClusterWithReplicasAsShards(context->getSettings()); + for (const auto & replicas : cluster->getShardsAddresses()) + { + for (const auto & node : replicas) + { + DistributedTask task(node, RemoteQueryExecutor::Extension{}); + res.emplace_back(std::make_pair(task, query_str)); + } + } + return res; + +} +std::optional>> InterpreterStageQuery::tryToMakeDistributedSelectQueries(ASTPtr from_query) +{ + auto storages = getSelectStorages(from_query); + bool has_groupby = ASTAnalyzeUtil::hasGroupByRecursively(from_query); + bool has_agg = ASTAnalyzeUtil::hasAggregationColumnRecursively(from_query); + if (storages.size() == 2) + { + for (const auto & storage : storages) + { + if (storage->getName() != StorageShuffleJoin::NAME) + { + LOG_TRACE(logger, "Not hash storage:{} {}", storage->getName(), StorageShuffleJoin::NAME); + return {}; + } + } + if (has_groupby || has_agg) + return {}; + + return buildSelectTasks(from_query); + + } + else if (storages.size() == 1) + { + if (storages[0]->getName() == StorageShuffleAggregation::NAME) + { + if (!has_groupby) + throw Exception(ErrorCodes::LOGICAL_ERROR, "There should be group by clause here. query: {}", queryToString(from_query)); + return buildSelectTasks(from_query); + } + } + return {}; + +} + +QueryBlockIO InterpreterStageQuery::buildTreeBlockIO(ASTPtr stage_query) +{ + InterpreterStageQuery interpreter(stage_query, context, options); + return {.block_io = std::make_shared(interpreter.execute()), .query = stage_query}; +} + +bool InterpreterStageQuery::couldRunParallelly(const ASTStageQuery * query) +{ + if (query->upstream_queries.size() != 1) + return true; + + const auto * insert_query = query->upstream_queries[0]->as(); + if (!insert_query) + return true; + + const auto * event_insert_query = query->current_query->as(); + if (!event_insert_query) + return true; + + if (event_insert_query->table_function->as()->name != TableFunctionClosedShuffle::name) + return true; + return false; +} + +} diff --git a/src/Interpreters/InterpreterStageQuery.h b/src/Interpreters/InterpreterStageQuery.h new file mode 100644 index 000000000000..11cb62db27d5 --- /dev/null +++ b/src/Interpreters/InterpreterStageQuery.h @@ -0,0 +1,48 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +class InterpreterStageQuery : public IInterpreter +{ +public: + explicit InterpreterStageQuery(ASTPtr query_, ContextPtr context_, SelectQueryOptions options_); + BlockIO execute() override; +private: + ASTPtr query; + ContextPtr context; + SelectQueryOptions options; + Poco::Logger * logger = &Poco::Logger::get("InterpreterStageQuery"); + + QueryBlockIO buildBlockIO(ASTPtr query_); + + QueryBlockIO buildInsertBlockIO(ASTPtr insert_query); + QueryBlockIO buildSelectBlockIO(ASTPtr select_query); + QueryBlockIO buildTreeBlockIO(ASTPtr stage_query); + + BlockIO execute(const QueryBlockIO & output_io, const QueryBlockIOs & input_block_ios); + + std::optional>> tryToMakeDistributedInsertQueries(ASTPtr from_query); + std::optional>> tryToMakeDistributedSelectQueries(ASTPtr from_query); + + std::vector getSelectStorages(ASTPtr ast); + + std::list> buildSelectTasks(ASTPtr from_query); + + static ASTPtr unwrapSingleSelectQuery(const ASTPtr & ast); + + static bool couldRunParallelly(const ASTStageQuery * query); +}; +} diff --git a/src/Interpreters/StorageDistributedTasksBuilder.cpp b/src/Interpreters/StorageDistributedTasksBuilder.cpp new file mode 100644 index 000000000000..75f497a59737 --- /dev/null +++ b/src/Interpreters/StorageDistributedTasksBuilder.cpp @@ -0,0 +1,37 @@ +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +StorageDistributedTaskBuilderFactory & StorageDistributedTaskBuilderFactory::getInstance() +{ + static StorageDistributedTaskBuilderFactory instance; + return instance; +} + +void StorageDistributedTaskBuilderFactory::registerMaker(const String & name, StorageDistributedTaskBuilderMaker maker) +{ + auto iter = makers.find(name); + if (iter != makers.end()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Duplicated maker : {}", name); + } + makers[name] = maker; +} + +StorageDistributedTaskBuilderPtr StorageDistributedTaskBuilderFactory::getBuilder(const String & name) +{ + auto iter = makers.find(name); + if (iter == makers.end()) + return nullptr; + return iter->second(); +} + +void registerAllStorageDistributedTaskBuilderMakers() +{ +} +} diff --git a/src/Interpreters/StorageDistributedTasksBuilder.h b/src/Interpreters/StorageDistributedTasksBuilder.h new file mode 100644 index 000000000000..1848933d139d --- /dev/null +++ b/src/Interpreters/StorageDistributedTasksBuilder.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +using DistributedTask = std::pair; +using DistributedTasks = std::vector; +class IStorageDistributedTaskBuilder +{ +public: + virtual ~IStorageDistributedTaskBuilder() = default; + virtual DistributedTasks getDistributedTasks(const String & cluster_name, ContextPtr context, ASTPtr ast, StoragePtr storage) = 0; +}; +using StorageDistributedTaskBuilderPtr = std::shared_ptr; + +using StorageDistributedTaskBuilderMaker = std::function; +class StorageDistributedTaskBuilderFactory : boost::noncopyable +{ +public: + static StorageDistributedTaskBuilderFactory & getInstance(); + void registerMaker(const String & name, StorageDistributedTaskBuilderMaker maker); + StorageDistributedTaskBuilderPtr getBuilder(const String & name); +protected: + StorageDistributedTaskBuilderFactory() = default; + +private: + std::map makers; + +}; + +void registerAllStorageDistributedTaskBuilderMakers(); +} diff --git a/src/Interpreters/executeQuery.cpp b/src/Interpreters/executeQuery.cpp index 24649128cee5..d8052837dbb1 100644 --- a/src/Interpreters/executeQuery.cpp +++ b/src/Interpreters/executeQuery.cpp @@ -54,6 +54,11 @@ #include #include #include +#include +#include +#include +#include +#include #include #include @@ -626,6 +631,41 @@ static std::tuple executeQueryImpl( } else { + if (context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY && std::dynamic_pointer_cast(ast)) + { + { + auto select_with_union = std::dynamic_pointer_cast(ast); + LOG_TRACE(&Poco::Logger::get("executeQuery"), "union_mode:{}, SelectUnionModes.size:{}, SelectUnionModesSet.size:{}, list_of_selects:{}", + select_with_union->union_mode, + select_with_union->list_of_modes.size(), + select_with_union->set_of_modes.size(), + select_with_union->list_of_selects->getID()); + } + if (!context->getSettings().use_cluster_for_distributed_shuffle.value.empty()) + { + MakeFunctionColumnAliasAction function_alias_action; + ASTDepthFirstVisitor function_alias_visitor(function_alias_action, ast); + auto function_alias_visit_result = function_alias_visitor.visit(); + LOG_TRACE(&Poco::Logger::get("executeQuery"), "function_alias_visit_result={}", queryToString(function_alias_visit_result)); + + NestedJoinQueryRewriteAction nested_join_query_action(context); + ASTDepthFirstVisitor nested_join_query_visitor(nested_join_query_action, function_alias_visit_result); + auto nested_join_query_visit_result = nested_join_query_visitor.visit(); + LOG_TRACE(&Poco::Logger::get("executeQuery"), "nested_join_query_visit_result={}", queryToString(nested_join_query_visit_result)); + + StageQueryDistributedJoinRewriteAction join_rewrite_action(context); + ASTDepthFirstVisitor join_rewrite_visitor(join_rewrite_action, nested_join_query_visit_result); + auto join_rewrite_result = join_rewrite_visitor.visit(); + LOG_TRACE(&Poco::Logger::get("executeQuery"), "join_rewrite_result={}", queryToString(join_rewrite_result)); + + StageQueryShuffleFinishEventRewriteAction add_finish_event_action; + ASTDepthFirstVisitor add_finish_event_visitor(add_finish_event_action, join_rewrite_result); + auto add_finish_event_result = add_finish_event_visitor.visit(); + LOG_TRACE(&Poco::Logger::get("executeQuery"), "add_finish_event_result={}", queryToString(add_finish_event_result)); + + ast = add_finish_event_result; + } + } interpreter = InterpreterFactory::get(ast, context, SelectQueryOptions(stage).setInternal(internal)); if (context->getCurrentTransaction() && !interpreter->supportsTransactions() && diff --git a/src/Parsers/ASTStageQuery.cpp b/src/Parsers/ASTStageQuery.cpp new file mode 100644 index 000000000000..4c85a5242834 --- /dev/null +++ b/src/Parsers/ASTStageQuery.cpp @@ -0,0 +1,47 @@ +#include +#include +#include +#include +#include "Functions/formatString.h" +#include "Parsers/IAST_fwd.h" +namespace DB +{ +ASTPtr ASTStageQuery::make(ASTPtr current_query, ASTs upstream_queries) +{ + auto stage_query = std::make_shared(); + stage_query->current_query = current_query; + stage_query->upstream_queries = upstream_queries; + stage_query->children.insert(stage_query->children.end(), upstream_queries.begin(), upstream_queries.end()); + stage_query->children.push_back(current_query); + return stage_query; +} +ASTPtr ASTStageQuery::clone() const +{ + auto res = std::make_shared(); + for (const auto & ast : upstream_queries) + { + res->upstream_queries.emplace_back(ast->clone()); + res->children.emplace_back(res->upstream_queries.back()); + } + res->current_query = current_query->clone(); + res->children.emplace_back(res->current_query); + return res; +} + +void ASTStageQuery::formatImpl(const FormatSettings & settings, FormatState & /*state*/, FormatStateStacked /*frame*/) const +{ + int i = 0; + for (const auto & ast : upstream_queries) + { + if (i) + settings.ostr << "\n"; + ast->format(settings); + settings.ostr << ";"; + i += 1; + } + if (i) + settings.ostr << "\n"; + current_query->format(settings); +} + +} diff --git a/src/Parsers/ASTStageQuery.h b/src/Parsers/ASTStageQuery.h new file mode 100644 index 000000000000..3b804f3cb0d0 --- /dev/null +++ b/src/Parsers/ASTStageQuery.h @@ -0,0 +1,26 @@ +#pragma once +#include +#include +/** + * For making one query into multi queries。 Eache query is a tree node. Current query will be executed only when + * all its children nodes finshed execution。 + * + * Mostly are designed for distributed shuffle join + * + * Be careful, ASTStageQuery cannot be an inner ast in ASTSelectQuery or ASTSelectWithUnionQuery. But inner ast in + * ASTStageQuery could be ASTSelectQuery, ASTSelectWithUnionQuery or ASTStageQuery. + */ +namespace DB +{ +class ASTStageQuery : public IAST +{ +public: + ASTs upstream_queries; // be relied by current_query + ASTPtr current_query; // rely on upstream_queries + String getID(char) const override { return "ASTStageQuery"; } + ASTPtr clone() const override; + void formatImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; + + static ASTPtr make(ASTPtr current_query, ASTs upstream_queries); +}; +} diff --git a/src/Processors/QueryPlan/StageQueryStep.cpp b/src/Processors/QueryPlan/StageQueryStep.cpp new file mode 100644 index 000000000000..84259c594a84 --- /dev/null +++ b/src/Processors/QueryPlan/StageQueryStep.cpp @@ -0,0 +1,98 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +StageQueryStep::StageQueryStep(ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_) + : context(context_), output_block_io(output_block_io_), input_block_ios(input_block_ios_) +{ +} + +QueryPipelineBuilderPtr StageQueryStep::updatePipeline(QueryPipelineBuilders pipelines, const BuildQueryPipelineSettings & /*settings*/) +{ + if (!pipelines.empty()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "pipelines should be empty"); + } + auto pipeline_builder_ptr = std::make_unique(); + auto & pipeline_builder = *pipeline_builder_ptr; + + auto input_block_io_transform = [&](const QueryBlockIOs & block_ios) + { + Processors processors; + for (const auto & block_io : block_ios) + { + processors.emplace_back(std::make_shared(context, block_io)); + } + return Pipe(processors); + }; + if (!input_block_ios.empty()) + { + auto pipe = input_block_io_transform(input_block_ios); + pipeline_builder.init(std::move(pipe)); + LOG_TRACE(logger, "pipeline input header: {}.", pipeline_builder.getHeader().dumpNames()); + } + else + { + QueryBlockIOs source_block_ios = {output_block_io}; + auto pipe = input_block_io_transform(source_block_ios); + pipeline_builder.init(std::move(pipe)); + return pipeline_builder_ptr; + } + + auto output_block_io_transform = [&](OutputPortRawPtrs outports) + { + std::vector headers; + for (auto & outport : outports) + { + LOG_TRACE(logger, "upstream output header:{}.", outport->getHeader().dumpNames()); + headers.emplace_back(outport->getHeader()); + } + auto processor = std::make_shared(context, output_block_io, headers); + auto & in_ports = processor->getInputs(); + size_t i = 0; + for (auto & in_port : in_ports) + { + connect(*outports[i], in_port); + i++; + } + return Processors{processor}; + }; + pipeline_builder.transform(output_block_io_transform); + return pipeline_builder_ptr; +} + +ParallelStageQueryStep::ParallelStageQueryStep( + ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_) + : context(context_), output_block_io(output_block_io_), input_block_ios(input_block_ios_) +{ +} + +QueryPipelineBuilderPtr +ParallelStageQueryStep::updatePipeline(QueryPipelineBuilders pipelines, const BuildQueryPipelineSettings & /*settings*/) +{ + if (!pipelines.empty()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "pipelines should be empty"); + } + + auto pipeline_builder_ptr = std::make_unique(); + auto & pipeline_builder = *pipeline_builder_ptr; + + Processors processors; + processors.emplace_back(std::make_shared(context, output_block_io, input_block_ios)); + pipeline_builder.init(Pipe(processors)); + return pipeline_builder_ptr; +} +} diff --git a/src/Processors/QueryPlan/StageQueryStep.h b/src/Processors/QueryPlan/StageQueryStep.h new file mode 100644 index 000000000000..3251c82cd615 --- /dev/null +++ b/src/Processors/QueryPlan/StageQueryStep.h @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include +#include +//#include +#include +#include +#include +#include +#include +namespace DB +{ +class StageQueryStep : public IQueryPlanStep +{ +public: + explicit StageQueryStep(ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_); + String getName() const override { return "StageQueryStep"; } + QueryPipelineBuilderPtr updatePipeline(QueryPipelineBuilders pipelines, const BuildQueryPipelineSettings & settings) override; +private: + ContextPtr context; + QueryBlockIO output_block_io; + QueryBlockIOs input_block_ios; + Poco::Logger * logger = &Poco::Logger::get("StageQueryStep"); +}; + +class ParallelStageQueryStep : public IQueryPlanStep +{ +public: + explicit ParallelStageQueryStep(ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_); + String getName() const override { return "ParallelStageQueryStep"; } + QueryPipelineBuilderPtr updatePipeline(QueryPipelineBuilders pipelines, const BuildQueryPipelineSettings & settings) override; +private: + ContextPtr context; + QueryBlockIO output_block_io; + QueryBlockIOs input_block_ios; + Poco::Logger * logger = &Poco::Logger::get("ParallelStageQueryStep"); +}; +} diff --git a/src/Processors/Transforms/StageQueryTransform.cpp b/src/Processors/Transforms/StageQueryTransform.cpp new file mode 100644 index 000000000000..cce89ac48d2e --- /dev/null +++ b/src/Processors/Transforms/StageQueryTransform.cpp @@ -0,0 +1,324 @@ +#include +#include +#include +#include "Common/Stopwatch.h" +#include "Common/ThreadPool.h" +#include +#include "Processors/Executors/PullingAsyncPipelineExecutor.h" +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +BlockIOSourceTransform::BlockIOSourceTransform(ContextPtr context_, const QueryBlockIO & block_io_) + : ISource(block_io_.block_io->pipeline.completed() ? Block{} : block_io_.block_io->pipeline.getHeader()) + , context(context_) + , block_io(block_io_.block_io) + , query(block_io_.query) +{ + is_pulling_pipeline = block_io->pipeline.pulling(); + is_completed_pipeline = block_io->pipeline.completed(); +} + +BlockIOSourceTransform::~BlockIOSourceTransform() +{ + LOG_TRACE(logger, "run query({}) in elapsedMilliseconds:{}", queryToString(query), elapsed); +} + +Chunk BlockIOSourceTransform::generate() +{ + if (unlikely(!watch)) + { + watch = std::make_unique(); + } + if (is_completed_pipeline) + { + LOG_TRACE(logger, "Run in completed mode. current query:{}", queryToString(query)); + CompletedPipelineExecutor executor(block_io->pipeline); + executor.execute(); + watch->stop(); + elapsed = watch->elapsedMilliseconds(); + return {}; + } + else if (is_pulling_pipeline) + { + if (!pulling_executor) + { + LOG_TRACE(logger, "Run in pulling mode. current query:{}", queryToString(query)); + pulling_executor = std::make_unique(block_io->pipeline); + } + Chunk res; + while (pulling_executor->pull(res)) + { + if (res) + { + return res; + } + } + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid pipeline mode"); + } + watch->stop(); + elapsed = watch->elapsedMilliseconds(); + return {}; +} + +static InputPorts headersToInputPorts(const std::vector & headers) +{ + InputPorts ports; + for (const auto & header : headers) + { + ports.emplace_back(header); + } + return ports; +} +StageBlockIOsConnectTransform::StageBlockIOsConnectTransform( + ContextPtr context_, const QueryBlockIO & output_block_io_, const std::vector & input_headers_) + : IProcessor(headersToInputPorts(input_headers_), {output_block_io_.block_io->pipeline.getHeader()}) + , context(context_) + , output_block_io(output_block_io_.block_io) + , query(output_block_io_.query) +{ + is_pulling_pipeline = output_block_io->pipeline.pulling(); + is_completed_pipeline = output_block_io->pipeline.completed(); +} + +StageBlockIOsConnectTransform::~StageBlockIOsConnectTransform() +{ + LOG_TRACE(logger, "run query({}) in elapsedMilliseconds:{}", queryToString(query), elapsed); +} + +IProcessor::Status StageBlockIOsConnectTransform::prepare() +{ + if (unlikely(!watch)) + { + watch = std::make_unique(); + } + auto & output = outputs.front(); + if (output.isFinished()) + { + for (auto & input : inputs) + { + input.close(); + } + LOG_TRACE(logger, "output.isFinished()"); + watch->stop(); + elapsed = watch->elapsedMilliseconds(); + return Status::Finished; + } + + if (!output.canPush()) + { + for (auto & input : inputs) + { + input.setNotNeeded(); + } + LOG_TRACE(logger, "!output.canPush()"); + return Status::PortFull; + } + + if (has_output) + { + output.push(std::move(chunk)); + has_output = false; + //LOG_TRACE(logger, "has_output"); + return Status::PortFull; + } + + if (has_input) + { + //LOG_TRACE(logger, "has_input"); + return Status::Ready; + } + + bool all_input_finished = true; + for (auto & input : inputs) + { + if (!input.isFinished()) + { + LOG_TRACE(logger, "try to pull upstream data. current query:{}", queryToString(query)); + all_input_finished = false; + input.setNeeded(); + if (input.hasData()) + (void)input.pullData(); + } + } + + if (!all_input_finished) + { + LOG_TRACE(logger, "need_blocked && !all_input_finished"); + return Status::NeedData; + } + + if (is_completed_pipeline) + { + LOG_TRACE(logger, "Run in completed mode. current query:{}", queryToString(query)); + CompletedPipelineExecutor executor(output_block_io->pipeline); + executor.execute(); + outputs.front().finish(); + watch->stop(); + elapsed = watch->elapsedMilliseconds(); + return Status::Finished; + } + else if (is_pulling_pipeline) + { + if (!pulling_executor) + { + LOG_TRACE(logger, "Run in pullig mode. current query:{}", queryToString(query)); + pulling_executor = std::make_unique(output_block_io->pipeline); + } + Chunk new_chunk; + while (pulling_executor->pull(new_chunk)) + { + LOG_TRACE(logger, "pull chunk rows:{}", new_chunk.getNumRows()); + if (new_chunk) + { + has_input = true; + chunk.swap(new_chunk); + return Status::Ready; + } + } + LOG_TRACE(logger, "pulling_executor->pull() = false. chunk rows:{}", new_chunk.getNumRows()); + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid pipeline"); + } + + outputs.front().finish(); + watch->stop(); + elapsed = watch->elapsedMilliseconds(); + return Status::Finished; +} + +void StageBlockIOsConnectTransform::work() +{ + if (has_input) + { + has_input = false; + has_output = true; + } +} + +ParallelStageBlockIOsTransform::ParallelStageBlockIOsTransform( + ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_) + : ISource(output_block_io_.block_io->pipeline.completed() ? Block{} : output_block_io_.block_io->pipeline.getHeader()) + , context(context_) + , output_block_io(output_block_io_) + , input_block_ios(input_block_ios_) +{ + is_completed_pipeline = output_block_io.block_io->pipeline.completed(); + is_pulling_pipeline = output_block_io.block_io->pipeline.pulling(); +} + +ParallelStageBlockIOsTransform::~ParallelStageBlockIOsTransform() +{ +#if 0 + for (auto & task : background_tasks) + { + task->deactivate(); + } +#else + if (thread_pool) + thread_pool->wait(); +#endif + LOG_TRACE(logger, "run query({}) in elapsedMilliseconds:{}", queryToString(output_block_io.query), elapsed); +} + +Chunk ParallelStageBlockIOsTransform::generate() +{ + if (unlikely(!has_start_background_tasks)) + { + startBackgroundTasks(); + } + + if (unlikely(!watch)) + { + watch = std::make_unique(); + } + + if (is_completed_pipeline) + { + CompletedPipelineExecutor executor(output_block_io.block_io->pipeline); + executor.execute(); + elapsed = watch->elapsedMilliseconds(); + return {}; + } + else if (is_pulling_pipeline) + { + if (unlikely(!pulling_executor)) + { + pulling_executor = std::make_unique(output_block_io.block_io->pipeline); + } + Chunk res; + while (pulling_executor->pull(res)) + { + if (res) + { + LOG_TRACE(logger, "read chunk . rows:{}. query:{}", res.getNumRows(), queryToString(output_block_io.query)); + return res; + } + } + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid pipeline mode"); + } + elapsed = watch->elapsedMilliseconds(); + return {}; +} + +void ParallelStageBlockIOsTransform::startBackgroundTasks() +{ + auto build_task = [](QueryBlockIO & block_io) + { + Stopwatch task_watch; + if (block_io.block_io->pipeline.completed()) + { + CompletedPipelineExecutor executor(block_io.block_io->pipeline); + executor.execute(); + } + else if (block_io.block_io->pipeline.pulling()) + { + PullingAsyncPipelineExecutor executor(block_io.block_io->pipeline); + Chunk res; + while (executor.pull(res)) + { + } + } + LOG_TRACE( + &Poco::Logger::get("ParallelStageBlockIOsTransform"), + "upstream query({}) run in elapsedMilliseconds:{}", + queryToString(block_io.query), + task_watch.elapsedMilliseconds()); + }; +#if 0 + auto & thread_pool = context->getSchedulePool(); + for (auto & block_io : input_block_ios) + { + background_tasks.emplace_back(thread_pool.createTask("BackgroundBlockIOTask", [build_task, &block_io](){ build_task(block_io);})); + background_tasks.back()->activateAndSchedule(); + } +#else + thread_pool = std::make_unique(input_block_ios.size()); + for (auto & block : input_block_ios) + { + thread_pool->scheduleOrThrowOnError([&]() { build_task(block); }); + } +#endif + has_start_background_tasks = true; +} +} diff --git a/src/Processors/Transforms/StageQueryTransform.h b/src/Processors/Transforms/StageQueryTransform.h new file mode 100644 index 000000000000..e40e5a293ae2 --- /dev/null +++ b/src/Processors/Transforms/StageQueryTransform.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ + +struct QueryBlockIO +{ + using BlockIOPtr = std::shared_ptr; + BlockIOPtr block_io; + ASTPtr query; // the query to generate this block_io +}; +using QueryBlockIOs = std::vector; +class BlockIOSourceTransform : public ISource +{ +public: + using BlockIOPtr = std::shared_ptr; + explicit BlockIOSourceTransform(ContextPtr context_, const QueryBlockIO & block_io_); + ~BlockIOSourceTransform() override; + String getName() const override { return "BlockIOSourceTransform"; } + Chunk generate() override; +private: + ContextPtr context; + BlockIOPtr block_io; + ASTPtr query; + bool is_pulling_pipeline; + bool is_completed_pipeline; + std::unique_ptr pulling_executor; + Poco::Logger * logger = &Poco::Logger::get("BlockIOSourceTransform"); + std::unique_ptr watch; + UInt64 elapsed = 0l; +}; + +class StageBlockIOsConnectTransform : public IProcessor +{ +public: + using BlockIOPtr = std::shared_ptr; + using BlockIOs = std::vector; + StageBlockIOsConnectTransform( + ContextPtr context_, + const QueryBlockIO & output_block_io_, + const std::vector & input_headers_); + ~StageBlockIOsConnectTransform() override; + String getName() const override { return "StageBlockIOsConnectTransform"; } + Status prepare() override; + void work() override; +private: + ContextPtr context; + BlockIOPtr output_block_io; + ASTPtr query; + bool is_pulling_pipeline; + bool is_completed_pipeline; + bool has_output = false; + bool has_input = false; + Chunk chunk; + std::unique_ptr pulling_executor; + Poco::Logger * logger = &Poco::Logger::get("StageBlockIOsConnectTransform"); + std::unique_ptr watch; + UInt64 elapsed = 0l; +}; + +class ParallelStageBlockIOsTransform : public ISource +{ +public: + explicit ParallelStageBlockIOsTransform(ContextPtr context_, const QueryBlockIO & output_block_io_, const QueryBlockIOs & input_block_ios_); + ~ParallelStageBlockIOsTransform() override; + String getName() const override { return "ParallelStageBlockIOsTransform"; } + Chunk generate() override; +private: + ContextPtr context; + QueryBlockIO output_block_io; + QueryBlockIOs input_block_ios; + + bool is_pulling_pipeline; + bool is_completed_pipeline; + + Chunk chunk; + std::unique_ptr pulling_executor; + Poco::Logger * logger = &Poco::Logger::get("ParallelStageBlockIOsTransform"); + + bool has_start_background_tasks = false; + // BackgroundSchedulePool::TaskHolder doesn't throw the inside exceptions + //std::vector background_tasks; + std::unique_ptr thread_pool; + void startBackgroundTasks(); + + std::unique_ptr watch; + UInt64 elapsed = 0l; +}; +} diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp new file mode 100644 index 000000000000..162e3203afa7 --- /dev/null +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp @@ -0,0 +1,242 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ProfileEvents +{ +extern const Event ClearTimeoutShuffleStorageSession; +} + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} + +void ShuffleBlockTable::addChunk(Chunk && chunk) +{ + if (chunk.hasRows())[[likely]] + { + if (is_sink_finished)[[unlikely]] + throw Exception(ErrorCodes::LOGICAL_ERROR, "Try in insert into a sink finished table({}.{})", session_id, table_id); + std::unique_lock lock(mutex); + rows += chunk.getNumRows(); + chunks.emplace_back(std::move(chunk)); + wait_more_data.notify_one(); + } + else + { + LOG_TRACE(logger, "add empty chunk"); + wait_more_data.notify_all(); + } +} + +Chunk ShuffleBlockTable::popChunk() +{ + std::unique_lock lock(mutex); + while (chunks.empty()) + { + if (!is_sink_finished) + { + wait_more_data.wait(lock, [&] { return is_sink_finished || !chunks.empty(); }); + } + else + { + break; + } + } + LOG_TRACE(logger, "{}.{} popChunk. isSinkFinished()={}, chunks.size()={}", session_id, table_id, is_sink_finished, chunks.size()); + + Chunk res; + if (likely(!chunks.empty())) + { + res.swap(chunks.front()); + chunks.pop_front(); + if (unlikely(!res.hasRows())) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Chunk should not be empty"); + } + } + lock.unlock(); + return res; +} + +ShuffleBlockSession::ShuffleBlockSession(const String & session_id_, ContextPtr context_) : session_id(session_id_), context(context_) +{ + created_timestamp = Poco::Timestamp().raw()/1000000; + timeout_second = context->getSettings().shuffle_storage_session_timeout; +} + +ShuffleBlockTablePtr ShuffleBlockSession::getTable(const String & table_id_, bool wait_created) +{ + std::unique_lock lock(mutex); + auto iter = tables.find(table_id_); + if (iter == tables.end()) + { + if (!wait_created) + { + LOG_INFO(logger, "Table({}) not found in session({})", table_id_, session_id); + return nullptr; + } + new_table_cond.wait(lock, [&]{iter = tables.find(table_id_); return iter != tables.end();}); + } + return iter->second; +} + +ShuffleBlockTablePtr ShuffleBlockSession::getOrSetTable(const String & table_id_, const Block & header_) +{ + ShuffleBlockTablePtr table; + bool is_new_table = false; + { + std::lock_guard lock(mutex); + auto iter = tables.find(table_id_); + if (iter == tables.end()) + { + LOG_TRACE(logger, "create new blocks table:{}.{}", session_id, table_id_); + table = std::make_shared(session_id, table_id_, header_); + tables[table_id_] = table; + is_new_table = true; + } + else + { + table = iter->second; + } + } + if (is_new_table) + { + new_table_cond.notify_all(); + return table; + } + + const auto & table_header = table->getHeader(); + if (!blocksHaveEqualStructure(table_header, header_)) + { + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Table({}-{}) exists with different header(), input header is :{}", + session_id, + table_id_, + table_header.dumpNames(), + header_.dumpNames()); + } + return table; +} + +void ShuffleBlockSession::releaseTable(const String & table_id_) +{ + LOG_INFO(logger, "release table {}.{}", session_id, table_id_); + size_t table_count = 0; + { + std::lock_guard lock(mutex); + auto iter = tables.find(table_id_); + if (iter != tables.end()) + { + iter->second->makeSinkFinished(); + } + tables.erase(table_id_); + table_count = tables.size(); + } + if (!table_count) + { + ShuffleBlockTableManager::getInstance().tryCloseSession(session_id); + } +} + +bool ShuffleBlockSession::isTimeout() const +{ + UInt64 now = Poco::Timestamp().raw()/1000000; + return (created_timestamp + timeout_second < now); +} + +ShuffleBlockTableManager & ShuffleBlockTableManager::getInstance() +{ + static ShuffleBlockTableManager storage; + return storage; +} + +ShuffleBlockSessionPtr ShuffleBlockTableManager::getSession(const String & session_id_) const +{ + std::lock_guard lock(mutex); + + auto iter = sessions.find(session_id_); + if (iter == sessions.end()) + { + LOG_INFO(logger, "Session() not found.", session_id_); + return nullptr; + } + return iter->second; +} + +ShuffleBlockSessionPtr ShuffleBlockTableManager::getOrSetSession(const String & session_id_, ContextPtr context_) +{ + std::lock_guard lock(mutex); + clearTimeoutSession(); + + auto iter = sessions.find(session_id_); + if (iter == sessions.end()) + { + LOG_TRACE(logger, "create new session:{}", session_id_); + auto session = std::make_shared(session_id_, context_); + sessions[session_id_] = session; + return session; + } + return iter->second; +} + +void ShuffleBlockTableManager::closeSession(const String & session_id_) +{ + LOG_TRACE(logger, "close session:{}", session_id_); + std::lock_guard lock(mutex); + sessions.erase(session_id_); +} + +void ShuffleBlockTableManager::tryCloseSession(const String & session_id_) +{ + std::lock_guard lock(mutex); + auto iter = sessions.find(session_id_); + if (iter == sessions.end()) + { + LOG_TRACE(logger, "try to close a non-exists session:{}", session_id_); + return; + } + auto & session = iter->second; + + if (session->getTablesNumber()) + { + LOG_INFO(logger, "session({}) has tables which are in used", session_id_); + return; + } + LOG_INFO(logger, "close session:{}", session_id_); + sessions.erase(session_id_); +} + +void ShuffleBlockTableManager::clearTimeoutSession() +{ + for (auto it = sessions.begin(); it != sessions.end();) + { + if (it->second->isTimeout()) + { + LOG_TRACE(&Poco::Logger::get("ShuffleBlockTableManager"), "Clear timeoout session: {}", it->first); + ProfileEvents::increment(ProfileEvents::ClearTimeoutShuffleStorageSession, 1); + sessions.erase(it++); + } + else + it++; + } +} + + +} diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.h b/src/Storages/DistributedShuffle/ShuffleBlockTable.h new file mode 100644 index 000000000000..1fe2f9b23bbb --- /dev/null +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.h @@ -0,0 +1,141 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +/// +/// How to clear all the data when a query session has finished ? +/// The following measures were taken at current +/// 1)Chunks in ShuffleBlockTable are read only once, so we use popChunkWithoutMutex() for loading a chunk. +/// That ensures that all chunks are released after the loading finish. +/// 2) When ShuffleBlockTable becomes empty, it will call ShuffleBlockSession::releaseTable() to +/// release it-self. +/// 3) When ShuffleBlockSession becomes empty, it will call ShuffleBlockTableManager::tryCloseSession() to +/// release it-self. +/// All above will ensure all datas are released in normal processing. But more need be considered, exceptions could +/// happen during the processing which make the release actions not be called. Some measures may be token. +/// 1) In TCPHandler, catch all exceptions , and make a session releasing action on all nodes +/// 2) All sessions have a max TTL, make background routine to check timeout sessions and clear them. +/// + +class ShuffleBlockTable +{ +public: + using ChunkIterator = std::list::iterator; + explicit ShuffleBlockTable( + const String & session_id_, + const String table_id_, + const Block & header_) + : session_id(session_id_) + , table_id(table_id_) + , header(header_) + {} + + ~ShuffleBlockTable() + { + LOG_TRACE(logger, "close table {}.{}", session_id, table_id); + } + + inline const Block & getHeader() const + { + return header; + } + inline const String & getSessionId() const { return session_id; } + inline const String & getTableId() const { return table_id; } + inline size_t getChunksNum() const { return chunks.size(); } + + + // TODO : Should make merge action to reduce small size chunks? + void addChunk(Chunk && chunk); + + Chunk popChunk(); + + void makeSinkFinished() + { + is_sink_finished = true; + LOG_INFO(logger, "{}.{} has total rows:{}", session_id, table_id, rows); + wait_more_data.notify_all(); + } + +private: + std::mutex mutex; + String session_id; + String table_id; + Block header; + std::atomic is_sink_finished = false; + std::list chunks; + std::condition_variable wait_more_data; + Poco::Logger * logger = &Poco::Logger::get("ShuffleBlockTable"); + size_t rows = 0; +}; +using ShuffleBlockTablePtr = std::shared_ptr; + +class ShuffleBlockSession +{ +public: + using Table = ShuffleBlockTable; + using TablePtr = ShuffleBlockTablePtr; + explicit ShuffleBlockSession(const String & session_id_, ContextPtr context_); + + TablePtr getTable(const String & table_id_, bool wait_created = false); + TablePtr getOrSetTable(const String & table_id_, const Block & header_); + void releaseTable(const String & table_id_); + + size_t getTablesNumber() const + { + std::lock_guard lock{mutex}; + return tables.size(); + } + + bool isTimeout() const; +private: + Poco::Logger * logger = &Poco::Logger::get("ShuffleBlockSession"); + String session_id; + ContextPtr context; + UInt64 created_timestamp; + UInt64 timeout_second; + mutable std::mutex mutex; + std::condition_variable new_table_cond; + std::unordered_map> tables; +}; +using ShuffleBlockSessionPtr = std::shared_ptr; + +class ShuffleBlockTableManager : public boost::noncopyable +{ +public: + using Session = ShuffleBlockSession; + using SessionPtr = ShuffleBlockSessionPtr; + + static ShuffleBlockTableManager & getInstance(); + SessionPtr getSession(const String & session_id_) const; + SessionPtr getOrSetSession(const String & session_id_, ContextPtr context_); + + void closeSession(const String & session_id_); + void tryCloseSession(const String & session_id_); +protected: + ShuffleBlockTableManager() = default; +private: + Poco::Logger * logger = &Poco::Logger::get("ShuffleBlockTableManager"); + + mutable std::mutex mutex; + std::unordered_map sessions; + + void clearTimeoutSession(); +}; + + +} diff --git a/src/Storages/DistributedShuffle/StorageShuffle.cpp b/src/Storages/DistributedShuffle/StorageShuffle.cpp new file mode 100644 index 000000000000..e94ea8ddc501 --- /dev/null +++ b/src/Storages/DistributedShuffle/StorageShuffle.cpp @@ -0,0 +1,771 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +class StorageShuffleSource : public SourceWithProgress, WithContext +{ +public: + StorageShuffleSource(ContextPtr context_, const String & session_id_, const String & table_id_, const Block & header_) + : SourceWithProgress(header_), WithContext(context_), session_id(session_id_), table_id(table_id_), header(header_) + { + } + + ~StorageShuffleSource() override + { + if (table) + { + session->releaseTable(table_id); + } + } + + String getName() const override { return "StorageShuffleSource"; } + Chunk generate() override + { + tryInitialize(); + if (unlikely(!table)) + { + LOG_INFO(logger, "{}.{} is not found.", session_id, table_id); + return {}; + } + Chunk res = table->popChunk(); + read_rows += res.getNumRows(); + return res; + } + +private: + Poco::Logger * logger = &Poco::Logger::get("StorageShuffleSource"); + bool has_initialized = false; + String session_id; + String table_id; + Block header; + ShuffleBlockSessionPtr session; + ShuffleBlockTablePtr table; + size_t read_rows = 0; + + void tryInitialize() + { + if (likely(has_initialized)) + return; + session = ShuffleBlockTableManager::getInstance().getOrSetSession(session_id, getContext()); + if (session) + { + table = session->getTable(table_id, true); + if (!table) + { + LOG_TRACE(logger, "Not found table:{}-{}", session_id, table_id); + } + } + else + { + LOG_TRACE(logger, "Not found session:{}", session_id); + } + + has_initialized = true; + } +}; + +class StorageShuffleSink : public SinkToStorage +{ +public: + explicit StorageShuffleSink( + ContextPtr context_, + const String & cluster_, + const String & session_id_, + const String & table_id_, + const Block & header_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_) + : SinkToStorage(header_) + , context(context_) + , cluster(cluster_) + , session_id(session_id_) + , table_id(table_id_) + , columns_desc(columns_) + , hash_expr_list(hash_expr_list_) + { + } + + void onFinish() override + { + if (has_initialized) + { + { + std::unique_lock chunk_lock(pending_blocks_mutex); + consume_finished = true; + } + pending_blocks_cond.notify_all(); + sink_thread_pool->wait(); + + for (auto & inserter : inserters) + { + std::lock_guard lock(inserter->mutex); + inserter->flush(); + inserter->inserter->onFinish(); + } + } + if (watch) + { + watch->stop(); + size_t elapse = watch->elapsedMilliseconds(); + LOG_INFO(logger, "{}.{} sink elapsed. {}", session_id, table_id, elapse); + } + } + String getName() const override { return "StorageShuffleSink"; } + +protected: + // split the block into multi blocks, and send to different nodes + void consume(Chunk chunk) override + { + if (!has_initialized) [[unlikely]] + { + initOnce(); + } + { + std::lock_guard lock(pending_blocks_mutex); + pending_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); + } + pending_blocks_cond.notify_one(); + } + +private: + Poco::Logger * logger = &Poco::Logger::get("StorageShuffleSink"); + ContextPtr context; + String cluster; + String session_id; + String table_id; + ColumnsDescription columns_desc; + ASTPtr hash_expr_list; + Strings hash_expr_columns_names; + + std::mutex init_mutex; + std::atomic has_initialized = false; + + std::atomic consume_finished = false; + std::mutex pending_blocks_mutex; + std::condition_variable pending_blocks_cond; + std::list pending_blocks; + std::unique_ptr sink_thread_pool; + + + struct InternalInserter + { + std::mutex mutex; + Blocks pending_blocks; + size_t current_rows = 0; + size_t max_rows_limit = DEFAULT_BLOCK_SIZE; + std::shared_ptr inserter; + void tryWrite(const Block & block) + { + if (current_rows + block.rows() > max_rows_limit && !pending_blocks.empty()) + { + auto to_send_block = concatenateBlocks(pending_blocks); + inserter->write(to_send_block); + pending_blocks.clear(); + current_rows = 0; + } + current_rows += block.rows(); + pending_blocks.push_back(block); + } + + void flush() + { + if (!pending_blocks.empty()) + { + auto to_send_block = concatenateBlocks(pending_blocks); + inserter->write(to_send_block); + current_rows = 0; + pending_blocks.clear(); + } + } + }; + using InternalInserterPtr = std::shared_ptr; + std::vector inserters; + std::vector> node_connections; + std::shared_ptr hash_expr_cols_actions; + String hash_expr_column_name; + + std::unique_ptr watch; + + + void initOnce() + { + watch = std::make_unique(); + std::lock_guard lock(init_mutex); + if (has_initialized) + return; + initInserters(); + initHashColumnsNames(); + + size_t thread_size = context->getSettings().max_threads; + //size_t thread_size = inserters.size(); + sink_thread_pool = std::make_unique(thread_size); + for (size_t i = 0; i < thread_size; ++i) + { + sink_thread_pool->scheduleOrThrowOnError( + [&] + { + while (true) + { + Block block; + { + std::unique_lock chunk_lock(pending_blocks_mutex); + if (pending_blocks.empty() && !consume_finished) + { + pending_blocks_cond.wait(chunk_lock, [&] { return consume_finished || !pending_blocks.empty(); }); + } + if (!pending_blocks.empty()) + { + block.swap(pending_blocks.front()); + pending_blocks.pop_front(); + } + else if (consume_finished) + { + break; + } + } + if (block.rows()) + { + std::vector split_blocks; + splitBlock(block, split_blocks); + sendBlocks(split_blocks); + } + } + }); + } + has_initialized = true; + } + + void initInserters() + { + // prepare insert sql + String insert_sql; + auto names_and_types = columns_desc.getAllPhysical(); + WriteBufferFromOwnString write_buf; + auto names = names_and_types.getNames(); + auto types = names_and_types.getTypes(); + for (size_t i = 0; i < names.size(); ++i) + { + if (i) + write_buf << ","; + write_buf << names[i] << " " << types[i]->getName(); + } + insert_sql = fmt::format( + "INSERT INTO FUNCTION {}('{}', '{}', '{}', '{}') VALUES", + TableFunctionLocalShuffle::name, + cluster, + session_id, + table_id, + write_buf.str()); + + const auto & settings = context->getSettings(); + // prepare remote call + auto cluster_addresses = getSortedShardAddresses(); + for (const auto & node : cluster_addresses) + { + auto connection = std::make_shared( + node.host_name, + node.port, + context->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "StorageShuffleSink", + node.compression, + node.secure); + node_connections.emplace_back(connection); + auto inserter = std::make_shared( + *connection, + ConnectionTimeouts{ + settings.connect_timeout.value.seconds() * 1000, + settings.send_timeout.value.seconds() * 1000, + settings.receive_timeout.value.seconds() * 1000}, + insert_sql, + context->getSettings(), + context->getClientInfo()); + auto internal_inserter = std::make_shared(); + internal_inserter->inserter = inserter; + internal_inserter->max_rows_limit = context->getSettingsRef().max_block_size; + inserters.emplace_back(internal_inserter); + } + } + + void initHashColumnsNames() + { + //hash_expr_columns_names + for (auto & child : hash_expr_list->children) + { + hash_expr_columns_names.emplace_back(queryToString(child)); + } + } + + Cluster::Addresses getSortedShardAddresses() const + { + auto cluster_instance = context->getCluster(cluster)->getClusterWithReplicasAsShards(context->getSettings()); + Cluster::Addresses addresses; + for (const auto & replicas : cluster_instance->getShardsAddresses()) + { + for (const auto & node : replicas) + { + addresses.emplace_back(node); + } + } + std::sort( + std::begin(addresses), + std::end(addresses), + [](const Cluster::Address & a, const Cluster::Address & b) { return a.host_name > b.host_name && a.port > b.port; }); + return addresses; + } + + void splitBlock(Block & original_block, std::vector & split_blocks) + { + size_t num_rows = original_block.rows(); + size_t num_shards = inserters.size(); + Block header = original_block.cloneEmpty(); + ColumnRawPtrs hash_cols; + for (const auto & hash_col_name : hash_expr_columns_names) + { + hash_cols.push_back(original_block.getByName(hash_col_name).column.get()); + } + IColumn::Selector selector(num_rows); + for (size_t i = 0; i < num_rows; ++i) + { + SipHash hash; + for (const auto & hash_col : hash_cols) + { + hash_col->updateHashWithValue(i, hash); + } + selector[i] = hash.get64() % num_shards; + } + + for (size_t i = 0; i < num_shards; ++i) + { + split_blocks.emplace_back(original_block.cloneEmpty()); + } + + auto columns_in_block = header.columns(); + for (size_t i = 0; i < columns_in_block; ++i) + { + auto split_columns = original_block.getByPosition(i).column->scatter(num_shards, selector); + for (size_t block_index = 0; block_index < num_shards; ++block_index) + { + split_blocks[block_index].getByPosition(i).column = std::move(split_columns[block_index]); + } + } + } + void sendBlocks(std::vector & blocks) + { + std::list to_send_blocks; + for (size_t i = 0, sz = blocks.size(); i < sz; ++i) + to_send_blocks.emplace_back(i); + while (!to_send_blocks.empty()) + { + for (auto iter = to_send_blocks.begin(); iter != to_send_blocks.end();) + { + auto & inserter = inserters[*iter]; + if (inserter->mutex.try_lock()) + { + auto & block = blocks[*iter]; + if (block.rows()) + { + inserter->tryWrite(block); + //inserter->inserter->write(block); + } + inserter->mutex.unlock(); + iter = to_send_blocks.erase(iter); + } + else + { + iter++; + } + } + } + } +}; + +class StorageLocalShuffleSink : public SinkToStorage +{ +public: + explicit StorageLocalShuffleSink(ContextPtr context_, const String & session_id_, const String & table_id_, const Block & header_) + : SinkToStorage(header_), context(context_), session_id(session_id_), table_id(table_id_) + { + auto session = ShuffleBlockTableManager::getInstance().getOrSetSession(session_id_, context_); + if (!session) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Get session({}) storage failed.", session_id_); + table_storage = session->getOrSetTable(table_id_, header_); + if (!table_storage) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Get session table({}-{}) failed.", session_id_, table_id_); + } + + void onFinish() override { LOG_INFO(logger, "{}.{} sink elapsed:{}", session_id, table_id, watch.elapsedMilliseconds()); } + String getName() const override { return "StorageLocalShuffleSink"; } + +protected: + void consume(Chunk chunk) override { table_storage->addChunk(std::move(chunk)); } + +private: + ContextPtr context; + String session_id; + String table_id; + ShuffleBlockTablePtr table_storage; + Poco::Logger * logger = &Poco::Logger::get("StorageLocalShuffleSink"); + Stopwatch watch; +}; + +/// +/// FIXEDME: Maybe need the initiator pass the cluster nodes, not query by the worker nodes. Since the cluster nodes set may change +/// +StorageShuffleBase::StorageShuffleBase( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_) + : IStorage(StorageID(session_id_, table_id_)) + , WithContext(context_) + , query(query_) + , cluster_name(cluster_name_) + , session_id(session_id_) + , table_id(table_id_) + , hash_expr_list(hash_expr_list_) +{ + StorageInMemoryMetadata storage_metadata; + storage_metadata.setColumns(columns_); + setInMemoryMetadata(storage_metadata); +} + +Pipe StorageShuffleBase::read( + const Names & column_names_, + const StorageSnapshotPtr & metadata_snapshot_, + SelectQueryInfo & query_info_, + ContextPtr context_, + QueryProcessingStage::Enum processed_stage_, + size_t /*max_block_size_*/, + unsigned num_streams) +{ + auto header = getInMemoryMetadata().getSampleBlock(); + + auto query_kind = context_->getClientInfo().query_kind; + if (query_kind != ClientInfo::QueryKind::INITIAL_QUERY) + { + auto source = std::make_shared(context_, session_id, table_id, header); + Pipe res(source); + res.resize(num_streams); + return res; + } + /// Since the query_info_.query has been rewritten, it may cause an ambiguous column exception in join case. + /// So we use the original_query here. + auto remote_query = queryToString(query_info_.original_query); + auto cluster = context_->getCluster(cluster_name)->getClusterWithReplicasAsShards(context_->getSettings()); + const Scalars & scalars = context_->hasQueryContext() ? context_->getQueryContext()->getScalars() : Scalars{}; + header = InterpreterSelectQuery(query_info_.query, context_, SelectQueryOptions(processed_stage_).analyze()).getSampleBlock(); + Pipes pipes; + for (const auto & replicas : cluster->getShardsAddresses()) + { + for (const auto & node : replicas) + { + auto connection = std::make_shared( + node.host_name, + node.port, + context_->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "StorageShuffleBase", + node.compression, + node.secure); + + auto remote_query_executor = std::make_shared( + connection, remote_query, header, context_, nullptr, scalars, Tables(), processed_stage_, RemoteQueryExecutor::Extension{}); + //LOG_TRACE9logger, "run query on node:{}. query:{}", node.host_name, remote_query); + pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); + } + } + metadata_snapshot_->check(column_names_); + auto res = Pipe::unitePipes(std::move(pipes)); + res.resize(num_streams); + return res; +} + +SinkToStoragePtr StorageShuffleBase::write(const ASTPtr & /*ast*/, const StorageMetadataPtr & /*storage_metadata*/, ContextPtr context_) +{ + // If there is no hash expression, just move blocks into local the shuffle block table + SinkToStoragePtr sinker; + if (hash_expr_list) + sinker = std::make_shared( + context_, + cluster_name, + session_id, + table_id, + getInMemoryMetadata().getSampleBlock(), + getInMemoryMetadata().getColumns(), + hash_expr_list); + else + sinker = std::make_shared(context_, session_id, table_id, getInMemoryMetadata().getSampleBlock()); + return sinker; +} + +QueryProcessingStage::Enum StorageShuffleBase::getQueryProcessingStage( + ContextPtr local_context, + QueryProcessingStage::Enum to_stage, + const StorageSnapshotPtr & /*metadata_snapshot*/, + SelectQueryInfo & query_info) const +{ + if (local_context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + { + /// When there is join in the query, cannot enable the two phases processing. It will cause + /// a column missing exception, if the result column is in the right table but not in the left table + auto select_query = query_info.query->as(); + if (select_query.join()) + return QueryProcessingStage::FetchColumns; + + if (to_stage >= QueryProcessingStage::WithMergeableState) + return QueryProcessingStage::WithMergeableState; + } + + return QueryProcessingStage::FetchColumns; +} + +StorageShuffleJoin::StorageShuffleJoin( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_) + : StorageShuffleBase(context_, query_, cluster_name_, session_id_, table_id_, columns_, hash_expr_list_) +{ + logger = &Poco::Logger::get("StorageShuffleJoin"); +} + + +StorageShuffleAggregation::StorageShuffleAggregation( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_) + : StorageShuffleBase(context_, query_, cluster_name_, session_id_, table_id_, columns_, hash_expr_list_) +{ + logger = &Poco::Logger::get("StorageShuffleAggregation"); +} + + +StorageLocalShuffle::StorageLocalShuffle( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_) + : IStorage(StorageID(session_id_, table_id_ + "_part")) + , WithContext(context_) + , query(query_) + , cluster_name(cluster_name_) + , session_id(session_id_) + , table_id(table_id_) +{ + StorageInMemoryMetadata storage_metadata; + storage_metadata.setColumns(columns_); + setInMemoryMetadata(storage_metadata); +} + +Pipe StorageLocalShuffle::read( + const Names & column_names_, + const StorageSnapshotPtr & metadata_snapshot_, + SelectQueryInfo & query_info_, + ContextPtr context_, + QueryProcessingStage::Enum processed_stage_, + size_t /*max_block_size_*/, + unsigned /*num_streams_*/) +{ + auto header = getInMemoryMetadata().getSampleBlock(); + auto query_kind = context_->getClientInfo().query_kind; + if (query_kind != ClientInfo::QueryKind::INITIAL_QUERY) + { + return Pipe(std::make_shared(context_, session_id, table_id, header)); + } + auto remote_query = queryToString(query_info_.original_query); + auto cluster = context_->getCluster(cluster_name)->getClusterWithReplicasAsShards(context_->getSettings()); + const Scalars & scalars = context_->hasQueryContext() ? context_->getQueryContext()->getScalars() : Scalars{}; + header = InterpreterSelectQuery(query_info_.query, context_, SelectQueryOptions(processed_stage_).analyze()).getSampleBlock(); + Pipes pipes; + for (const auto & replicas : cluster->getShardsAddresses()) + { + for (const auto & node : replicas) + { + auto connection = std::make_shared( + node.host_name, + node.port, + context_->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "StorageLocalShuffle", + node.compression, + node.secure); + + auto remote_query_executor = std::make_shared( + connection, remote_query, header, context_, nullptr, scalars, Tables(), processed_stage_, RemoteQueryExecutor::Extension{}); + //LOG_TRACE9logger, "run query on node:{}. query:{}", node.host_name, remote_query); + pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); + } + } + metadata_snapshot_->check(column_names_); + return Pipe::unitePipes(std::move(pipes)); +} + +SinkToStoragePtr StorageLocalShuffle::write(const ASTPtr & /*ast*/, const StorageMetadataPtr & /*storage_metadata*/, ContextPtr context_) +{ + auto sinker = std::make_shared(context_, session_id, table_id, getInMemoryMetadata().getSampleBlock()); + return sinker; +} + + +QueryProcessingStage::Enum StorageLocalShuffle::getQueryProcessingStage( + ContextPtr local_context, + QueryProcessingStage::Enum to_stage, + const StorageSnapshotPtr & /*metadata_snapshot*/, + SelectQueryInfo & query_info) const +{ + //LOG_TRACE9logger, "query:{}, to_stage:{}, query_kind:{}", queryToString(query_info.query), to_stage, local_context->getClientInfo().query_kind); + if (local_context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + { + // When there is join in the query, cannot enable the two phases processing. It will cause + // a column missing exception, if the result column is in the right table but not in the left table + auto select_query = query_info.query->as(); + if (select_query.join()) + return QueryProcessingStage::FetchColumns; + + if (to_stage >= QueryProcessingStage::WithMergeableState) + return QueryProcessingStage::WithMergeableState; + } + + return QueryProcessingStage::FetchColumns; +} + +StorageShuffleClose::StorageShuffleClose( + ContextPtr context_, + ASTPtr query_, + const ColumnsDescription & columns_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_) + : IStorage(StorageID(session_id_, table_id_ + "_closed")) + , WithContext(context_) + , query(query_) + , cluster_name(cluster_name_) + , session_id(session_id_) + , table_id(table_id_) +{ + StorageInMemoryMetadata storage_metadata; + storage_metadata.setColumns(columns_); + setInMemoryMetadata(storage_metadata); +} + +Pipe StorageShuffleClose::read( + const Names & /*column_names_*/, + const StorageSnapshotPtr & /*metadata_snapshot_*/, + SelectQueryInfo & /*query_info_*/, + ContextPtr /*context_*/, + QueryProcessingStage::Enum /*processed_stage_*/, + size_t /*max_block_size_*/, + unsigned /*num_streams_*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "StorageShuffleClose has no read implementation"); +} + +class StorageShuffleCloseSink : public SinkToStorage +{ +public: + explicit StorageShuffleCloseSink( + ContextPtr context_, const String & cluster_name_, const String & session_id_, const String & table_id_, const Block & header_) + : SinkToStorage(header_) + , context(context_) + , cluster_name(cluster_name_) + , session_id(session_id_) + , table_id(table_id_) + , header(header_) + { + } + + String getName() const override { return "StorageShuffleCloseSink"; } + +protected: + void consume(Chunk /*chunk*/) override + { + auto session = ShuffleBlockTableManager::getInstance().getSession(session_id); + if (!session) + return; + auto table = session->getTable(table_id); + if (!table) + return; + table->makeSinkFinished(); + LOG_INFO(logger, "mark table {}.{} sink finished", session_id, table_id); + } + +private: + ContextPtr context; + String cluster_name; + String session_id; + String table_id; + Block header; + Poco::Logger * logger = &Poco::Logger::get("StorageShuffleCloseSink"); +}; + +SinkToStoragePtr StorageShuffleClose::write(const ASTPtr & /*ast*/, const StorageMetadataPtr & /*storage_metadata*/, ContextPtr context_) +{ + auto sinker + = std::make_shared(context_, cluster_name, session_id, table_id, getInMemoryMetadata().getSampleBlock()); + return sinker; +} + +} diff --git a/src/Storages/DistributedShuffle/StorageShuffle.h b/src/Storages/DistributedShuffle/StorageShuffle.h new file mode 100644 index 000000000000..7c6a19f86e2e --- /dev/null +++ b/src/Storages/DistributedShuffle/StorageShuffle.h @@ -0,0 +1,179 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace DB +{ + + +class StorageShuffleBase : public IStorage, WithContext +{ +public: + virtual String getName() const override = 0; + Pipe read( + const Names & column_names_, + const StorageSnapshotPtr & metadata_snapshot_, + SelectQueryInfo & query_info_, + ContextPtr context_, + QueryProcessingStage::Enum processed_stage_, + size_t max_block_size_, + unsigned num_streams_) override; + + SinkToStoragePtr write( + const ASTPtr & ast, + const StorageMetadataPtr & storage_metadata, + ContextPtr context + ) override; + + StorageShuffleBase( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_); + QueryProcessingStage::Enum getQueryProcessingStage( + ContextPtr local_context, + QueryProcessingStage::Enum to_stage, + const StorageSnapshotPtr & metadata_snapshot, + SelectQueryInfo & query_info) const override; + + /// + /// Do not set it true until we find other a way to signal the table has finished sinking + /// + bool supportsParallelInsert() const override { return false; } +protected: + Poco::Logger * logger; + ASTPtr query; + String cluster_name; + String session_id; + String table_id; + ASTPtr hash_expr_list; +}; + +/// +/// StorageShuffleJoin and StorageShuffleAggregation have the same behavior. the only difference is the name +/// for some special usage purpose. +/// +class StorageShuffleJoin : public StorageShuffleBase +{ +public: + static constexpr auto NAME = "StorageShuffleJoin"; + String getName() const override { return NAME; } + + StorageShuffleJoin( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_); +}; + +class StorageShuffleAggregation : public StorageShuffleBase +{ +public: + static constexpr auto NAME = "StorageShuffleAggregation"; + String getName() const override { return NAME; } + StorageShuffleAggregation( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_, + ASTPtr hash_expr_list_); +}; + +/// +/// use for reading/writing the local ShuffleBlockTable +/// + +class StorageLocalShuffle : public IStorage, WithContext +{ +public: + static constexpr auto NAME = "StorageLocalShuffle"; + String getName() const override { return NAME; } + Pipe read( + const Names & column_names_, + const StorageSnapshotPtr & metadata_snapshot_, + SelectQueryInfo & query_info_, + ContextPtr context_, + QueryProcessingStage::Enum processed_stage_, + size_t max_block_size_, + unsigned num_streams_) override; + SinkToStoragePtr write( + const ASTPtr & ast, + const StorageMetadataPtr & storage_metadata, + ContextPtr context + ) override; + + StorageLocalShuffle( + ContextPtr context_, + ASTPtr query_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_, + const ColumnsDescription & columns_ + ); + + QueryProcessingStage::Enum getQueryProcessingStage( + ContextPtr local_context, + QueryProcessingStage::Enum to_stage, + const StorageSnapshotPtr & metadata_snapshot, + SelectQueryInfo & query_info) const override; +private: + Poco::Logger * logger = &Poco::Logger::get("StorageLocalShuffle"); + ASTPtr query; + String cluster_name; + String session_id; + String table_id; +}; + + +/// +/// Use to close a shuffle table +/// + +class StorageShuffleClose : public IStorage, WithContext +{ +public: + static constexpr auto NAME = "StorageShuffleClose"; + String getName() const override { return NAME; } + Pipe read( + const Names & column_names_, + const StorageSnapshotPtr & metadata_snapshot_, + SelectQueryInfo & query_info_, + ContextPtr context_, + QueryProcessingStage::Enum processed_stage_, + size_t max_block_size_, + unsigned num_streams_) override; + + SinkToStoragePtr write( + const ASTPtr & ast, + const StorageMetadataPtr & storage_metadata, + ContextPtr context_) override; + + StorageShuffleClose( + ContextPtr context_, + ASTPtr query_, + const ColumnsDescription & columns_, + const String & cluster_name_, + const String & session_id_, + const String & table_id_); + +private: + ASTPtr query; + String cluster_name; + String session_id; + String table_id; + +}; + +} diff --git a/src/TableFunctions/TableFunctionShuffle.cpp b/src/TableFunctions/TableFunctionShuffle.cpp new file mode 100644 index 000000000000..e91d40e4adc0 --- /dev/null +++ b/src/TableFunctions/TableFunctionShuffle.cpp @@ -0,0 +1,197 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +void TableFunctionLocalShuffle::parseArguments(const ASTPtr & ast_function_, ContextPtr context_) +{ + ASTs & args_func = ast_function_->children; + if (args_func.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Table function '{}' must have arguments.", getName()); + + ASTs & args = args_func.at(0)->children; + String usage_message = fmt::format( + "The signature of function {} is:\b" + "- session_id, table_id, table structure descrition", + getName()); + + if (args.size() < 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, usage_message); + + for (auto & arg : args) + arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context_); + + cluster_name = args[0]->as().value.safeGet(); + session_id = args[1]->as().value.safeGet(); + table_id = args[2]->as().value.safeGet(); + table_structure = args[3]->as().value.safeGet(); + + columns = parseColumnsListFromString(table_structure, context_); +} +ColumnsDescription TableFunctionLocalShuffle::getActualTableStructure(ContextPtr) const +{ + return columns; +} + +StoragePtr TableFunctionLocalShuffle::executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & /*table_name*/, ColumnsDescription /*cached_columns*/) const +{ + StoragePtr storage = std::make_shared(context, ast_function, cluster_name, session_id, table_id, columns); + return storage; +} + +void TableFunctionShuffleJoin::parseArguments(const ASTPtr & ast_function_, ContextPtr context_) +{ + ASTs & args_func = ast_function_->children; + if (args_func.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Table function '{}' must have arguments.", getName()); + + ASTs & args = args_func.at(0)->children; + String usage_message = fmt::format( + "The signature of function {} is:\b" + "- cluster_name, session_id, table_id, table structure descrition, [hash key expression list]", + getName()); + + if (args.size() < 4) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, usage_message); + + for (auto & arg : args) + arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context_); + + cluster_name = args[0]->as().value.safeGet(); + session_id = args[1]->as().value.safeGet(); + table_id = args[2]->as().value.safeGet(); + table_structure = args[3]->as().value.safeGet(); + columns = parseColumnsListFromString(table_structure, context_); + if (args.size() >= 5) + table_hash_exprs = args[4]->as().value.safeGet(); + + auto settings = context_->getSettings(); + ParserExpressionList hash_expr_list_parser(true); + if (!table_hash_exprs.empty()) + hash_expr_list_ast = parseQuery( + hash_expr_list_parser, table_hash_exprs, "Parsing table hash keys", settings.max_query_size, settings.max_parser_depth); +} +ColumnsDescription TableFunctionShuffleJoin::getActualTableStructure(ContextPtr) const +{ + return columns; +} + +StoragePtr TableFunctionShuffleJoin::executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & /*table_name*/, ColumnsDescription /*cached_columns*/) const +{ + StoragePtr storage = std::make_shared(context, ast_function, cluster_name, session_id, table_id, columns, hash_expr_list_ast); + return storage; +} + +void TableFunctionShuffleAggregation::parseArguments(const ASTPtr & ast_function_, ContextPtr context_) +{ + ASTs & args_func = ast_function_->children; + if (args_func.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Table function '{}' must have arguments.", getName()); + + ASTs & args = args_func.at(0)->children; + String usage_message = fmt::format( + "The signature of function {} is:\b" + "- cluster_name, session_id, table_id, table structure descrition, [hash key expression list]", + getName()); + + if (args.size() < 4) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, usage_message); + + for (auto & arg : args) + arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context_); + + cluster_name = args[0]->as().value.safeGet(); + session_id = args[1]->as().value.safeGet(); + table_id = args[2]->as().value.safeGet(); + table_structure = args[3]->as().value.safeGet(); + columns = parseColumnsListFromString(table_structure, context_); + if (args.size() >= 5) + table_hash_exprs = args[4]->as().value.safeGet(); + + auto settings = context_->getSettings(); + ParserExpressionList hash_expr_list_parser(true); + if (!table_hash_exprs.empty()) + hash_expr_list_ast = parseQuery( + hash_expr_list_parser, table_hash_exprs, "Parsing table hash keys", settings.max_query_size, settings.max_parser_depth); +} + +ColumnsDescription TableFunctionShuffleAggregation::getActualTableStructure(ContextPtr /*context*/) const +{ + return columns; +} + +StoragePtr TableFunctionShuffleAggregation::executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & /*table_name*/, ColumnsDescription /*cached_columns*/) const +{ + StoragePtr storage = std::make_shared(context, ast_function, cluster_name, session_id, table_id, columns, hash_expr_list_ast); + LOG_TRACE(&Poco::Logger::get("TableFunctionShuffleAggregation"), "create agg storage. {}", storage->getName()); + return storage; +} + + +void TableFunctionClosedShuffle::parseArguments(const ASTPtr & ast_function, ContextPtr context) +{ + ASTs & args_func = ast_function->children; + if (args_func.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Table function '{}' must have arguments.", getName()); + + ASTs & args = args_func.at(0)->children; + String usage_message = fmt::format( + "The signature of function {} is:\b" + "- cluster_name, session_id, table_id", + getName()); + + if (args.size() < 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, usage_message); + + for (auto & arg : args) + arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context); + + cluster_name = args[0]->as()->value.safeGet(); + session_id = args[1]->as()->value.safeGet(); + table_id = args[2]->as()->value.safeGet(); + + String table_structure = "n UInt32"; + columns = parseColumnsListFromString(table_structure, context); +} + +ColumnsDescription TableFunctionClosedShuffle::getActualTableStructure(ContextPtr) const +{ + return columns; +} + +StoragePtr TableFunctionClosedShuffle::executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & /*table_name*/, ColumnsDescription /*cached_columns*/) const +{ + StoragePtr storage = std::make_shared(context, ast_function, columns, cluster_name, session_id, table_id); + return storage; +} + +void registerTableFunctionShuffle(TableFunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); +} +} diff --git a/src/TableFunctions/TableFunctionShuffle.h b/src/TableFunctions/TableFunctionShuffle.h new file mode 100644 index 000000000000..9384b7373ced --- /dev/null +++ b/src/TableFunctions/TableFunctionShuffle.h @@ -0,0 +1,111 @@ +#pragma once +#include +#include +#include +#include +namespace DB +{ +/** + * Only for inserting chunks into different nodes. + * + */ +class TableFunctionLocalShuffle : public ITableFunction +{ +public: + static constexpr auto name = "localShuffleStorage"; + static constexpr auto storage_type_name = "StorageLocalShuffle";// but no storage is registered + std::string getName() const override { return name; } + bool hasStaticStructure() const override { return true; } + StoragePtr executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & table_name, ColumnsDescription cached_columns) const override; + const char * getStorageTypeName() const override { return storage_type_name; } + ColumnsDescription getActualTableStructure(ContextPtr) const override; + void parseArguments(const ASTPtr & ast_function_, ContextPtr context_) override; + +private: + Poco::Logger * logger = &Poco::Logger::get("TableFunctionLocalShuffle"); + + String cluster_name; + String session_id; + String table_id; + String table_structure; + + ColumnsDescription columns; + +}; + +class TableFunctionShuffleJoin : public ITableFunction +{ +public: + static constexpr auto name = "shuffleJoinStorage"; + static constexpr auto storage_type_name = "StorageShuffleJoin";// but no storage is registered + std::string getName() const override { return name; } + bool hasStaticStructure() const override { return true; } + StoragePtr executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & table_name, ColumnsDescription cached_columns) const override; + const char * getStorageTypeName() const override { return storage_type_name; } + ColumnsDescription getActualTableStructure(ContextPtr) const override; + void parseArguments(const ASTPtr & ast_function_, ContextPtr context_) override; + +private: + Poco::Logger * logger = &Poco::Logger::get("TableFunctionShuffleJoin"); + + // followings are args + String cluster_name; + String session_id; + String table_id; + String table_structure; + String table_hash_exprs; + + ColumnsDescription columns; + ASTPtr hash_expr_list_ast; +}; + +class TableFunctionShuffleAggregation : public ITableFunction +{ +public: + static constexpr auto name = "shuffleAggregationStorage"; + static constexpr auto storage_type_name = "StorageShuffleAggregation";// but no storage is registered + std::string getName() const override { return name; } + + bool hasStaticStructure() const override { return true; } + StoragePtr executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & table_name, ColumnsDescription cached_columns) const override; + const char * getStorageTypeName() const override { return storage_type_name; } + ColumnsDescription getActualTableStructure(ContextPtr) const override; + void parseArguments(const ASTPtr & ast_function_, ContextPtr context_) override; + +private: + Poco::Logger * logger = &Poco::Logger::get("TableFunctionShuffleAggregation"); + + // followings are args + String cluster_name; + String session_id; + String table_id; + String table_structure; + String table_hash_exprs; + + ColumnsDescription columns; + ASTPtr hash_expr_list_ast; +}; + +class TableFunctionClosedShuffle : public ITableFunction +{ +public: + static constexpr auto name = "closedShulleStorage"; + static constexpr auto storage_type_name = "ClosedShuffleStorage"; + std::string getName() const override { return name; } + + bool hasStaticStructure() const override { return true; } + StoragePtr executeImpl( + const ASTPtr & ast_function, ContextPtr context, const std::string & table_name, ColumnsDescription cached_columns) const override; + const char * getStorageTypeName() const override { return storage_type_name; } + ColumnsDescription getActualTableStructure(ContextPtr) const override; + void parseArguments(const ASTPtr & ast_function_, ContextPtr context_) override; +private: + String cluster_name; + String session_id; + String table_id; + ColumnsDescription columns; +}; +} diff --git a/src/TableFunctions/registerTableFunctions.cpp b/src/TableFunctions/registerTableFunctions.cpp index f8f1530d587b..d2da81fbe16e 100644 --- a/src/TableFunctions/registerTableFunctions.cpp +++ b/src/TableFunctions/registerTableFunctions.cpp @@ -58,6 +58,8 @@ void registerTableFunctions() registerTableFunctionDictionary(factory); registerTableFunctionFormat(factory); + + registerTableFunctionShuffle(factory); } } diff --git a/src/TableFunctions/registerTableFunctions.h b/src/TableFunctions/registerTableFunctions.h index 5a54a64cb3c3..772ad097fdd6 100644 --- a/src/TableFunctions/registerTableFunctions.h +++ b/src/TableFunctions/registerTableFunctions.h @@ -57,6 +57,8 @@ void registerTableFunctionDictionary(TableFunctionFactory & factory); void registerTableFunctionFormat(TableFunctionFactory & factory); +void registerTableFunctionShuffle(TableFunctionFactory & factory); + void registerTableFunctions(); } From ff781b0a93bd3fa6b38973890d27b71c9d83bbbd Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 16 May 2022 17:49:24 +0800 Subject: [PATCH 2/4] fixed: session close --- src/Interpreters/ASTRewriters/ASTBuildUtil.h | 39 +++++----- .../ASTRewriters/IdentRenameRewriteAction.cpp | 5 +- .../IdentifierQualiferRemoveAction.cpp | 12 ++++ ...eryDistributedAggregationRewriteAction.cpp | 9 ++- ...StageQueryDistributedJoinRewriteAction.cpp | 34 ++++++--- src/Interpreters/InterpreterStageQuery.cpp | 23 +++++- .../DistributedShuffle/ShuffleBlockTable.cpp | 71 ++++++++++++------- .../DistributedShuffle/ShuffleBlockTable.h | 28 +++++++- .../DistributedShuffle/StorageShuffle.cpp | 16 +++-- 9 files changed, 164 insertions(+), 73 deletions(-) diff --git a/src/Interpreters/ASTRewriters/ASTBuildUtil.h b/src/Interpreters/ASTRewriters/ASTBuildUtil.h index cfc042960229..33728197c85d 100644 --- a/src/Interpreters/ASTRewriters/ASTBuildUtil.h +++ b/src/Interpreters/ASTRewriters/ASTBuildUtil.h @@ -16,16 +16,15 @@ class ASTBuildUtil static String toTableStructureDescription(const ColumnWithDetailNameAndTypes & columns); /** - * @brief Create a Shuffle Table Function object + * Create a Shuffle Table Function object * - * @param function_name which shuffle function to use. see TableFunctionShuffle.h - * @param session_id session_od - * @param cluster_name cluster name - * @param table_id table_id - * @param columns describe the table structure. etc. 'x int, y string' - * @param hash_expression_list hash expression list for shuffle hashing. etc. 'x, y' - * @param alias table alias - * @return ASTPtr ASTFunction + * - function_name which shuffle function to use. see TableFunctionShuffle.h + * - session_id session_od + * - cluster_name cluster name + * - table_id table_id + * - columns describe the table structure. etc. 'x int, y string' + * - hash_expression_list hash expression list for shuffle hashing. etc. 'x, y' + * - alias table alias */ static ASTPtr createShuffleTableFunction( const String & function_name, @@ -37,32 +36,32 @@ class ASTBuildUtil const String & alias = ""); /** - * @brief Create a Table Function Insert Select Query object + * Create a Table Function Insert Select Query object * - * @param table_function must be a ASTFunction - * @param select_query must be a ASTSelectWithUnionQuery - * @return ASTPtr it's a ASTInsertQuery + * - table_function must be a ASTFunction + * - select_query must be a ASTSelectWithUnionQuery + * return ASTPtr it's a ASTInsertQuery */ static ASTPtr createTableFunctionInsertSelectQuery(ASTPtr table_function, ASTPtr select_query); /** - * @brief Create a ASTSelectWithUnionQuery with a ASTSelectQuery + * Create a ASTSelectWithUnionQuery with a ASTSelectQuery * - * @param select_query must be a ASTSelectQuery - * @return ASTPtr it's a ASTSelectWithUnionQuery + * - select_query must be a ASTSelectQuery + * return ASTPtr it's a ASTSelectWithUnionQuery */ static ASTPtr wrapSelectQuery(const ASTSelectQuery * select_query); /** - * @brief Create a Select Expression object + * Create a Select Expression object * - * @param names_and_types Use the names to build the select expression - * @return ASTPtr + * - names_and_types Use the names to build the select expression + * return ASTPtr */ static ASTPtr createSelectExpression(const NamesAndTypesList & names_and_types); /** - * @brief Update ASTSelectQuery::TABLES ASTTableExpressions + * Update ASTSelectQuery::TABLES ASTTableExpressions */ static void updateSelectQueryTables(ASTSelectQuery * select_query, const ASTTableExpression * table_expr_); diff --git a/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp index 2294d200b254..a9cab1526f2e 100644 --- a/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/IdentRenameRewriteAction.cpp @@ -58,10 +58,7 @@ ASTs IdentifierRenameAction::collectChildren(const ASTPtr & ast) { return {}; } - else - { - throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknow ast type {}. {}", ast->getID(), queryToString(ast)); - } + return {}; } void IdentifierRenameAction::beforeVisitChildren(const ASTPtr & ast) { diff --git a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp index 7910ee1b94bd..c7696dd1355a 100644 --- a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp +++ b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp @@ -6,6 +6,7 @@ #include #include #include +#include "Parsers/ASTExpressionList.h" namespace DB { @@ -22,6 +23,10 @@ ASTs IdentifiterQualiferRemoveAction::collectChildren(const ASTPtr & ast) { children = function_ast->arguments->children; } + else if (const auto * expr_list_ast = ast->as()) + { + children = expr_list_ast->children; + } return children; } @@ -49,6 +54,13 @@ void IdentifiterQualiferRemoveAction::visit(const ASTPtr & ast) auto * result_ast = frame->result_ast->as(); result_ast->alias = ident_ast->tryGetAlias(); } + else if (const auto * expr_list_ast = ast->as()) + { + auto frame = frames.getTopFrame(); + frame->result_ast = std::make_shared(); + auto * result_ast = frame->result_ast->as(); + result_ast->children = frame->children_results; + } else { throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid ast({}): {}", ast->getID(), queryToString(ast)); diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp index e0ec51530b7a..31dfb2f8bc3c 100644 --- a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp @@ -38,6 +38,7 @@ StageQueryDistributedAggregationRewriteAction::StageQueryDistributedAggregationR void StageQueryDistributedAggregationRewriteAction::beforeVisitChildren(const ASTPtr & ast) { + LOG_TRACE(logger, "{} push frame. {}:{}", __LINE__, ast->getID(), queryToString(ast)); frames.pushFrame(ast); } @@ -59,6 +60,7 @@ ASTs StageQueryDistributedAggregationRewriteAction::collectChildren(const ASTPtr { if (!ast) return {}; + LOG_TRACE(logger, "{} collectChildren. {}:{}", __LINE__, ast->getID(), queryToString(ast)); ASTs children; if (const auto * union_select_ast = ast->as()) { @@ -153,6 +155,7 @@ void StageQueryDistributedAggregationRewriteAction::visit(const ASTSubquery * su void StageQueryDistributedAggregationRewriteAction::visit(const ASTSelectQuery * select_ast) { + LOG_TRACE(logger, "{} frame size={}, select ast={}", __LINE__, frames.size(), queryToString(*select_ast)); auto frame = frames.getTopFrame(); if (frame->children_results.empty()) // join query { @@ -165,12 +168,13 @@ void StageQueryDistributedAggregationRewriteAction::visit(const ASTSelectQuery * throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTStageQuery is expected. return query is : {}", queryToString(rewrite_ast)); auto * return_select_ast = stage_query->current_query->as(); if (!return_select_ast) - throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTSelectQuery is expected. return query is : {}", queryToString(stage_query->current_query)); + throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTSelectQuery is expected. return query is :(id={}) {}", stage_query->current_query->getID(), queryToString(stage_query->current_query)); ASTs upstream_queries; upstream_queries.insert(upstream_queries.end(), stage_query->upstream_queries.begin(), stage_query->upstream_queries.end()); frame->upstream_queries = std::vector{upstream_queries}; + //frame->children_results.emplace_back(stage_query->current_query); frame->result_ast = stage_query->current_query; } else @@ -207,6 +211,7 @@ void StageQueryDistributedAggregationRewriteAction::visit(const ASTSelectQuery * if (frames.size() == 1) { + LOG_TRACE(logger, "{} top select ast:{}", __LINE__, queryToString(*select_ast)); frame->mergeChildrenUpstreamQueries(); frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); } @@ -318,7 +323,7 @@ void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithGroupby( auto required_columns = collect_columns_visitor.visit().required_columns; auto insert_query = createShuffleInsert( - TableFunctionLocalShuffle::name, + TableFunctionShuffleAggregation::name, rewrite_table_expr, ColumnWithDetailNameAndType::toNamesAndTypesList(required_columns[0]), select_ast->groupBy()); diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp index 40034530db0c..2e6e00e72c12 100644 --- a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace DB @@ -213,10 +215,11 @@ void StageQueryDistributedJoinRewriteAction::visit(const ASTSelectQuery * select void StageQueryDistributedJoinRewriteAction::visitSelectQueryWithAggregation(const ASTSelectQuery * select_ast) { + LOG_TRACE(logger, "{} visitSelectQueryWithAggregation:{}", __LINE__, queryToString(*select_ast)); auto frame = frames.getTopFrame(); - StageQueryDistributedAggregationRewriteAction distributed_join_action(context, id_generator); - ASTDepthFirstVisitor distributed_join_visitor(distributed_join_action, select_ast->clone()); - auto rewrite_ast = distributed_join_visitor.visit(); + StageQueryDistributedAggregationRewriteAction distributed_agg_action(context, id_generator); + ASTDepthFirstVisitor distributed_agg_visitor(distributed_agg_action, select_ast->clone()); + auto rewrite_ast = distributed_agg_visitor.visit(); auto * stage_query = rewrite_ast->as(); if (!stage_query) throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTStageQuery is expected. return query is : {}", queryToString(rewrite_ast)); @@ -245,11 +248,6 @@ void StageQueryDistributedJoinRewriteAction::visitSelectQueryOnJoin(const ASTSel ASTBuildUtil::createTablesInSelectQueryElement( frame->children_results[1]->as(), result_ast->join()->table_join) ->as()); - if (frames.size() == 1) - { - frame->mergeChildrenUpstreamQueries(); - frame->result_ast = ASTStageQuery::make(frame->result_ast, frame->upstream_queries[0]); - } } else { @@ -371,6 +369,7 @@ bool StageQueryDistributedJoinRewriteAnalyzer::isApplicableJoinType() else return false;// using clause + #if 1 // if it is a special storage, return false const auto * join_ast = from_query->join(); const auto & table_to_join = join_ast->table_expression->as(); @@ -378,11 +377,12 @@ bool StageQueryDistributedJoinRewriteAnalyzer::isApplicableJoinType() { auto joined_table_id = context->resolveStorageID(table_to_join.database_and_table_name); StoragePtr storage = DatabaseCatalog::instance().tryGetTable(joined_table_id, context); - if (storage) + if (std::dynamic_pointer_cast(storage)) { return false; } } + #endif return true; } @@ -510,8 +510,20 @@ bool StageQueryDistributedJoinRewriteAnalyzer::collectHashKeysOnAnd( auto * func = ast->as(); for (auto & arg : func->arguments->children) { - if (!collectHashKeysOnEqual(arg, keys_list, alias_columns)) - return false; + auto * arg_func = arg->as(); + if (arg_func) + { + if (arg_func->name == "equals") + { + if (!collectHashKeysOnEqual(arg, keys_list, alias_columns)) + return false; + } + else if (arg_func->name == "and") + { + if (!collectHashKeysOnAnd(arg, keys_list, alias_columns)) + return false; + } + } } return true; } diff --git a/src/Interpreters/InterpreterStageQuery.cpp b/src/Interpreters/InterpreterStageQuery.cpp index d723df6611e2..162d13204ae4 100644 --- a/src/Interpreters/InterpreterStageQuery.cpp +++ b/src/Interpreters/InterpreterStageQuery.cpp @@ -233,7 +233,14 @@ std::optional>> InterpreterStageQue } return res; } - + auto * union_select_query = insert_query->select->as(); + for (auto & child : union_select_query->list_of_selects->children) + { + auto * select_query = child->as(); + if (select_query->limitBy() || select_query->limitByLength() || select_query->limitLength() || select_query->limitOffset() + || select_query->limitByOffset()) + return {}; + } auto storages = getSelectStorages(insert_query->select); bool has_groupby = ASTAnalyzeUtil::hasGroupByRecursively(from_query); bool has_agg = ASTAnalyzeUtil::hasAggregationColumnRecursively(from_query); @@ -366,6 +373,20 @@ std::optional>> InterpreterStageQue auto storages = getSelectStorages(from_query); bool has_groupby = ASTAnalyzeUtil::hasGroupByRecursively(from_query); bool has_agg = ASTAnalyzeUtil::hasAggregationColumnRecursively(from_query); + + // if the query has order by or limit, run in single node + auto * union_select_query = from_query->as(); + for (auto & child : union_select_query->list_of_selects->children) + { + auto * select_query = child->as(); + if (select_query->orderBy() || select_query->limitBy() || select_query->limitByLength() || select_query->limitLength() + || select_query->limitOffset() || select_query->limitByOffset()) + { + LOG_TRACE(logger, "query has order by or limit. [{}] {}", child->getID(), queryToString(child)); + return {}; + } + } + if (storages.size() == 2) { for (const auto & storage : storages) diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp index 162e3203afa7..4a81ddca0f57 100644 --- a/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp @@ -28,7 +28,7 @@ namespace ErrorCodes void ShuffleBlockTable::addChunk(Chunk && chunk) { - if (chunk.hasRows())[[likely]] + if (chunk.hasRows()) [[likely]] { if (is_sink_finished)[[unlikely]] throw Exception(ErrorCodes::LOGICAL_ERROR, "Try in insert into a sink finished table({}.{})", session_id, table_id); @@ -39,7 +39,7 @@ void ShuffleBlockTable::addChunk(Chunk && chunk) } else { - LOG_TRACE(logger, "add empty chunk"); + LOG_TRACE(logger, "Add an empty chunk. table({}.{})", session_id, table_id); wait_more_data.notify_all(); } } @@ -58,19 +58,18 @@ Chunk ShuffleBlockTable::popChunk() break; } } - LOG_TRACE(logger, "{}.{} popChunk. isSinkFinished()={}, chunks.size()={}", session_id, table_id, is_sink_finished, chunks.size()); + //LOG_TRACE(logger, "{}.{} popChunk. isSinkFinished()={}, chunks.size()={}", session_id, table_id, is_sink_finished, chunks.size()); Chunk res; - if (likely(!chunks.empty())) + if (!chunks.empty()) [[likely]] { res.swap(chunks.front()); chunks.pop_front(); if (unlikely(!res.hasRows())) { - throw Exception(ErrorCodes::LOGICAL_ERROR, "Chunk should not be empty"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Chunk should not be empty. table({}.{})", session_id, table_id); } } - lock.unlock(); return res; } @@ -138,7 +137,6 @@ ShuffleBlockTablePtr ShuffleBlockSession::getOrSetTable(const String & table_id_ void ShuffleBlockSession::releaseTable(const String & table_id_) { LOG_INFO(logger, "release table {}.{}", session_id, table_id_); - size_t table_count = 0; { std::lock_guard lock(mutex); auto iter = tables.find(table_id_); @@ -147,11 +145,6 @@ void ShuffleBlockSession::releaseTable(const String & table_id_) iter->second->makeSinkFinished(); } tables.erase(table_id_); - table_count = tables.size(); - } - if (!table_count) - { - ShuffleBlockTableManager::getInstance().tryCloseSession(session_id); } } @@ -161,26 +154,49 @@ bool ShuffleBlockSession::isTimeout() const return (created_timestamp + timeout_second < now); } +void ShuffleBlockSession::decreaseRef() +{ + if (ref_count == 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Session({}) ref_count = 0", session_id); + ref_count -= 1; +} + +String ShuffleBlockSession::dumpTables() +{ + std::lock_guard lock(mutex); + String table_names; + int i = 0; + for (auto & table : tables) + { + if (i) + table_names += ","; + i += 1; + table_names += table.first; + } + return table_names; +} + ShuffleBlockTableManager & ShuffleBlockTableManager::getInstance() { static ShuffleBlockTableManager storage; return storage; } -ShuffleBlockSessionPtr ShuffleBlockTableManager::getSession(const String & session_id_) const +std::shared_ptr ShuffleBlockTableManager::getSession(const String & session_id_) const { std::lock_guard lock(mutex); auto iter = sessions.find(session_id_); if (iter == sessions.end()) { - LOG_INFO(logger, "Session() not found.", session_id_); + LOG_INFO(logger, "Session({}) not found.", session_id_); return nullptr; } - return iter->second; + iter->second->increaseRef(); + return std::make_shared(iter->second); } -ShuffleBlockSessionPtr ShuffleBlockTableManager::getOrSetSession(const String & session_id_, ContextPtr context_) +std::shared_ptr ShuffleBlockTableManager::getOrSetSession(const String & session_id_, ContextPtr context_) { std::lock_guard lock(mutex); clearTimeoutSession(); @@ -190,19 +206,24 @@ ShuffleBlockSessionPtr ShuffleBlockTableManager::getOrSetSession(const String & { LOG_TRACE(logger, "create new session:{}", session_id_); auto session = std::make_shared(session_id_, context_); + session->increaseRef(); sessions[session_id_] = session; - return session; + return std::make_shared(session); } - return iter->second; + iter->second->increaseRef(); + return std::make_shared(iter->second); } -void ShuffleBlockTableManager::closeSession(const String & session_id_) +ShuffleBlockSessionHolder::ShuffleBlockSessionHolder(ShuffleBlockSessionPtr session_) + : session(session_) { - LOG_TRACE(logger, "close session:{}", session_id_); - std::lock_guard lock(mutex); - sessions.erase(session_id_); } +ShuffleBlockSessionHolder::~ShuffleBlockSessionHolder() +{ +} + + void ShuffleBlockTableManager::tryCloseSession(const String & session_id_) { std::lock_guard lock(mutex); @@ -213,10 +234,10 @@ void ShuffleBlockTableManager::tryCloseSession(const String & session_id_) return; } auto & session = iter->second; - - if (session->getTablesNumber()) + session->decreaseRef(); + if (session->getRefCount() || session->getTablesNumber()) { - LOG_INFO(logger, "session({}) has tables which are in used", session_id_); + LOG_INFO(logger, "session({}) is in used. ref={}, tables={}", session_id_, session->getRefCount(), session->dumpTables()); return; } LOG_INFO(logger, "close session:{}", session_id_); diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.h b/src/Storages/DistributedShuffle/ShuffleBlockTable.h index 1fe2f9b23bbb..608fdc1eda55 100644 --- a/src/Storages/DistributedShuffle/ShuffleBlockTable.h +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.h @@ -91,6 +91,7 @@ class ShuffleBlockSession using TablePtr = ShuffleBlockTablePtr; explicit ShuffleBlockSession(const String & session_id_, ContextPtr context_); + const String & getSessionId() const { return session_id; } TablePtr getTable(const String & table_id_, bool wait_created = false); TablePtr getOrSetTable(const String & table_id_, const Block & header_); void releaseTable(const String & table_id_); @@ -102,6 +103,11 @@ class ShuffleBlockSession } bool isTimeout() const; + + void increaseRef() { ref_count += 1; } + void decreaseRef(); + inline UInt32 getRefCount() const { return ref_count; } + String dumpTables(); private: Poco::Logger * logger = &Poco::Logger::get("ShuffleBlockSession"); String session_id; @@ -111,20 +117,36 @@ class ShuffleBlockSession mutable std::mutex mutex; std::condition_variable new_table_cond; std::unordered_map> tables; + std::atomic ref_count = 0; }; using ShuffleBlockSessionPtr = std::shared_ptr; +class ShuffleBlockSessionHolder +{ +public: + ShuffleBlockSessionHolder() = default; + explicit ShuffleBlockSessionHolder(ShuffleBlockSessionPtr session_); + + ~ShuffleBlockSessionHolder(); + + ShuffleBlockSession & value() { return *session; } + +private: + ShuffleBlockSessionPtr session; + +}; + class ShuffleBlockTableManager : public boost::noncopyable { public: using Session = ShuffleBlockSession; using SessionPtr = ShuffleBlockSessionPtr; + using SessionHolder = ShuffleBlockSessionHolder; static ShuffleBlockTableManager & getInstance(); - SessionPtr getSession(const String & session_id_) const; - SessionPtr getOrSetSession(const String & session_id_, ContextPtr context_); + std::shared_ptr getSession(const String & session_id_) const; + std::shared_ptr getOrSetSession(const String & session_id_, ContextPtr context_); - void closeSession(const String & session_id_); void tryCloseSession(const String & session_id_); protected: ShuffleBlockTableManager() = default; diff --git a/src/Storages/DistributedShuffle/StorageShuffle.cpp b/src/Storages/DistributedShuffle/StorageShuffle.cpp index e94ea8ddc501..867d71075a2b 100644 --- a/src/Storages/DistributedShuffle/StorageShuffle.cpp +++ b/src/Storages/DistributedShuffle/StorageShuffle.cpp @@ -55,7 +55,7 @@ class StorageShuffleSource : public SourceWithProgress, WithContext { if (table) { - session->releaseTable(table_id); + session->value().releaseTable(table_id); } } @@ -63,7 +63,7 @@ class StorageShuffleSource : public SourceWithProgress, WithContext Chunk generate() override { tryInitialize(); - if (unlikely(!table)) + if (!table) [[unlikely]] { LOG_INFO(logger, "{}.{} is not found.", session_id, table_id); return {}; @@ -79,7 +79,7 @@ class StorageShuffleSource : public SourceWithProgress, WithContext String session_id; String table_id; Block header; - ShuffleBlockSessionPtr session; + std::shared_ptr session; ShuffleBlockTablePtr table; size_t read_rows = 0; @@ -90,7 +90,7 @@ class StorageShuffleSource : public SourceWithProgress, WithContext session = ShuffleBlockTableManager::getInstance().getOrSetSession(session_id, getContext()); if (session) { - table = session->getTable(table_id, true); + table = session->value().getTable(table_id, true); if (!table) { LOG_TRACE(logger, "Not found table:{}-{}", session_id, table_id); @@ -430,10 +430,10 @@ class StorageLocalShuffleSink : public SinkToStorage explicit StorageLocalShuffleSink(ContextPtr context_, const String & session_id_, const String & table_id_, const Block & header_) : SinkToStorage(header_), context(context_), session_id(session_id_), table_id(table_id_) { - auto session = ShuffleBlockTableManager::getInstance().getOrSetSession(session_id_, context_); + session = ShuffleBlockTableManager::getInstance().getOrSetSession(session_id_, context_); if (!session) throw Exception(ErrorCodes::LOGICAL_ERROR, "Get session({}) storage failed.", session_id_); - table_storage = session->getOrSetTable(table_id_, header_); + table_storage = session->value().getOrSetTable(table_id_, header_); if (!table_storage) throw Exception(ErrorCodes::LOGICAL_ERROR, "Get session table({}-{}) failed.", session_id_, table_id_); } @@ -448,6 +448,7 @@ class StorageLocalShuffleSink : public SinkToStorage ContextPtr context; String session_id; String table_id; + std::shared_ptr session; ShuffleBlockTablePtr table_storage; Poco::Logger * logger = &Poco::Logger::get("StorageLocalShuffleSink"); Stopwatch watch; @@ -745,7 +746,8 @@ class StorageShuffleCloseSink : public SinkToStorage auto session = ShuffleBlockTableManager::getInstance().getSession(session_id); if (!session) return; - auto table = session->getTable(table_id); + LOG_TRACE(logger, "session({}), ref={}, tables={}", session_id, session->value().getRefCount(), session->value().dumpTables()); + auto table = session->value().getTable(table_id); if (!table) return; table->makeSinkFinished(); From dddcffce16ef540ae17e275a0b9b34df7faf0df2 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 18 May 2022 17:41:11 +0800 Subject: [PATCH 3/4] improve hash dispatch blocks 1) extend remote inserters to accelerate data transform 2) make advantage of low cardinality to improve groupby --- .../CollectRequiredColumnsAction.cpp | 9 ++ .../IdentifierQualiferRemoveAction.cpp | 2 + .../NestedJoinQueryRewriteAction.cpp | 4 +- .../NestedJoinQueryRewriteAction.h | 1 + ...eryDistributedAggregationRewriteAction.cpp | 50 +++++++- ...QueryDistributedAggregationRewriteAction.h | 3 + .../Transforms/StageQueryTransform.cpp | 16 --- .../DistributedShuffle/ShuffleBlockTable.cpp | 55 +++++---- .../DistributedShuffle/ShuffleBlockTable.h | 19 +-- .../DistributedShuffle/StorageShuffle.cpp | 109 ++++++++++++------ 10 files changed, 174 insertions(+), 94 deletions(-) diff --git a/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp index 87844036d8b0..459bc4cbd8f0 100644 --- a/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp +++ b/src/Interpreters/ASTRewriters/CollectRequiredColumnsAction.cpp @@ -147,6 +147,15 @@ void CollectRequiredColumnsAction::visit(const ASTIdentifier * ident_ast) .alias_name = ident_ast->tryGetAlias(), .type = col.type }; + /* + LOG_TRACE( + &Poco::Logger::get("CollectRequiredColumnsAction"), + "add ident @ {}, full name:{}, short name:{}, alias:{}.", + *best_pos, + ident_ast->name(), + ident_ast->shortName(), + ident_ast->tryGetAlias()); + */ final_result.required_columns[*best_pos].push_back(column_metadta); found = true; added_names.insert(ident_ast->name()); diff --git a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp index c7696dd1355a..20dbd20f7a09 100644 --- a/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp +++ b/src/Interpreters/ASTRewriters/IdentifierQualiferRemoveAction.cpp @@ -32,6 +32,8 @@ ASTs IdentifiterQualiferRemoveAction::collectChildren(const ASTPtr & ast) void IdentifiterQualiferRemoveAction::visit(const ASTPtr & ast) { + if (!ast) + return; if (const auto * function_ast = ast->as()) { auto frame = frames.getTopFrame(); diff --git a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp index 3f298fbccc93..ef16ba36c25e 100644 --- a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.cpp @@ -160,8 +160,8 @@ void NestedJoinQueryRewriteAction::visit(const ASTSelectQuery * select_ast) { auto ident = std::make_shared(col.splitedFullName()); ident->alias = col.alias_name; - if (ident->alias.empty()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Alias name is expected for {}", col.full_name); + //if (ident->alias.empty()) + // throw Exception(ErrorCodes::LOGICAL_ERROR, "Alias is expected for {}", col.full_name); nested_select_expr_list->children.emplace_back(ident); } nested_select_ast->setExpression(ASTSelectQuery::Expression::SELECT, nested_select_expr_list); diff --git a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h index f8a74dc66017..a259477220f5 100644 --- a/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h +++ b/src/Interpreters/ASTRewriters/NestedJoinQueryRewriteAction.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace DB { /** diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp index 31dfb2f8bc3c..4edb6778fd11 100644 --- a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.cpp @@ -300,7 +300,12 @@ void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithAggregat void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithGroupby(const ASTSelectQuery * select_ast) { + LOG_TRACE(logger, "visitSelectQueryWithGroupby:{}", queryToString(*select_ast)); auto frame = frames.getTopFrame(); + auto tables = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); + if (tables.size() != 1) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Tables size should be 1"); + auto * rewrite_table_expr = frame->children_results[0]->as(); if (!rewrite_table_expr) throw Exception(ErrorCodes::LOGICAL_ERROR, "ASTTableExpression is expected. return query is : {}", queryToString(frame->children_results[0])); @@ -313,21 +318,36 @@ void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithGroupby( rewrite_table_expr->subquery->as()->setAlias(table_alias); } - - auto tables = getDatabaseAndTablesWithColumns(getTableExpressions(*select_ast), context, true, true); - if (tables.size() != 1) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Tables size should be 1"); + + if (isAllRequiredColumnsLowCardinality(select_ast->groupBy(), tables)) + { + CollectQueryStoragesAction collect_storage_action(context); + ASTDepthFirstVisitor collect_storage_visitor(collect_storage_action, frame->children_results[0]); + auto storages = collect_storage_visitor.visit(); + /// If all columns used in the groupby clasue are low cardinality, do not shuffle the data and + /// run the groupby in the two-phase way. + if (storages.size() > 1) + visitSelectQueryWithAggregation(select_ast); + return; + } CollectRequiredColumnsAction collect_columns_action(tables); ASTDepthFirstVisitor collect_columns_visitor(collect_columns_action, select_ast->clone()); auto required_columns = collect_columns_visitor.visit().required_columns; - auto insert_query = createShuffleInsert( + ASTPtr insert_query = createShuffleInsert( TableFunctionShuffleAggregation::name, rewrite_table_expr, ColumnWithDetailNameAndType::toNamesAndTypesList(required_columns[0]), select_ast->groupBy()); + auto * insert_query_ptr = insert_query->as(); + auto * insert_select_ptr = insert_query_ptr->select->as()->list_of_selects->children[0]->as(); + IdentifiterQualiferRemoveAction remove_qualifier_action; + ASTDepthFirstVisitor remove_qualifier_visitor(remove_qualifier_action, select_ast->where()); + auto where_expr = remove_qualifier_visitor.visit(); + insert_select_ptr->setExpression(ASTSelectQuery::Expression::WHERE, std::move(where_expr)); + ASTs upstream_queries; frame->mergeChildrenUpstreamQueries(); if (!frame->upstream_queries[0].empty()) @@ -341,6 +361,7 @@ void StageQueryDistributedAggregationRewriteAction::visitSelectQueryWithGroupby( frame->result_ast = select_ast->clone(); auto * result_select_ast = frame->result_ast->as(); + result_select_ast->setExpression(ASTSelectQuery::Expression::WHERE, nullptr); ASTBuildUtil::updateSelectQueryTables( result_select_ast, ASTBuildUtil::createTablesInSelectQueryElement(table_function->as())->as()); @@ -371,4 +392,23 @@ ASTPtr StageQueryDistributedAggregationRewriteAction::createShuffleInsert( return ASTBuildUtil::createTableFunctionInsertSelectQuery(table_function, ASTBuildUtil::wrapSelectQuery(select_query)); } + +bool StageQueryDistributedAggregationRewriteAction::isAllRequiredColumnsLowCardinality(const ASTPtr & ast, const TablesWithColumns & tables) +{ + CollectRequiredColumnsAction collect_columns_action(tables); + ASTDepthFirstVisitor collect_columns_visitor(collect_columns_action, ast); + auto required_columns = collect_columns_visitor.visit().required_columns; + for (auto & cols : required_columns) + { + for (auto & col : cols) + { + LOG_TRACE(logger, "check group by col. {} {}", col.full_name, col.type->getName()); + if (!col.type->lowCardinality()) + return false; + } + } + LOG_TRACE(logger, "all columns are locaCardinality"); + return true; + +} } diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h index 97f645e46863..832d4319e64f 100644 --- a/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedAggregationRewriteAction.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB { @@ -67,5 +68,7 @@ class StageQueryDistributedAggregationRewriteAction : public EmptyASTDepthFirstV ASTPtr createShuffleInsert( const String & table_function_name, ASTTableExpression * table_expr, const NamesAndTypesList & table_desc, ASTPtr groupby_clause); + + bool isAllRequiredColumnsLowCardinality(const ASTPtr & ast, const TablesWithColumns & tables); }; } diff --git a/src/Processors/Transforms/StageQueryTransform.cpp b/src/Processors/Transforms/StageQueryTransform.cpp index cce89ac48d2e..f829bb964186 100644 --- a/src/Processors/Transforms/StageQueryTransform.cpp +++ b/src/Processors/Transforms/StageQueryTransform.cpp @@ -226,15 +226,8 @@ ParallelStageBlockIOsTransform::ParallelStageBlockIOsTransform( ParallelStageBlockIOsTransform::~ParallelStageBlockIOsTransform() { -#if 0 - for (auto & task : background_tasks) - { - task->deactivate(); - } -#else if (thread_pool) thread_pool->wait(); -#endif LOG_TRACE(logger, "run query({}) in elapsedMilliseconds:{}", queryToString(output_block_io.query), elapsed); } @@ -305,20 +298,11 @@ void ParallelStageBlockIOsTransform::startBackgroundTasks() queryToString(block_io.query), task_watch.elapsedMilliseconds()); }; -#if 0 - auto & thread_pool = context->getSchedulePool(); - for (auto & block_io : input_block_ios) - { - background_tasks.emplace_back(thread_pool.createTask("BackgroundBlockIOTask", [build_task, &block_io](){ build_task(block_io);})); - background_tasks.back()->activateAndSchedule(); - } -#else thread_pool = std::make_unique(input_block_ios.size()); for (auto & block : input_block_ios) { thread_pool->scheduleOrThrowOnError([&]() { build_task(block); }); } -#endif has_start_background_tasks = true; } } diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp index 4a81ddca0f57..41e706cc8de8 100644 --- a/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.cpp @@ -30,11 +30,18 @@ void ShuffleBlockTable::addChunk(Chunk && chunk) { if (chunk.hasRows()) [[likely]] { - if (is_sink_finished)[[unlikely]] - throw Exception(ErrorCodes::LOGICAL_ERROR, "Try in insert into a sink finished table({}.{})", session_id, table_id); - std::unique_lock lock(mutex); - rows += chunk.getNumRows(); - chunks.emplace_back(std::move(chunk)); + { + std::unique_lock lock(mutex); + if (is_sink_finished) [[unlikely]] + throw Exception(ErrorCodes::LOGICAL_ERROR, "Try in insert into a sink finished table({}.{})", session_id, table_id); + while (remained_rows > max_rows_limit) + wait_consume_data.wait(lock); + rows += chunk.getNumRows(); + remained_rows += chunk.getNumRows(); + if (remained_rows > max_rows) + max_rows = remained_rows; + chunks.emplace_back(std::move(chunk)); + } wait_more_data.notify_one(); } else @@ -46,30 +53,34 @@ void ShuffleBlockTable::addChunk(Chunk && chunk) Chunk ShuffleBlockTable::popChunk() { - std::unique_lock lock(mutex); - while (chunks.empty()) + Chunk res; { - if (!is_sink_finished) - { - wait_more_data.wait(lock, [&] { return is_sink_finished || !chunks.empty(); }); - } - else + std::unique_lock lock(mutex); + while (chunks.empty()) { - break; + if (!is_sink_finished) + { + wait_more_data.wait(lock, [&] { return is_sink_finished || !chunks.empty(); }); + } + else + { + break; + } } - } - //LOG_TRACE(logger, "{}.{} popChunk. isSinkFinished()={}, chunks.size()={}", session_id, table_id, is_sink_finished, chunks.size()); + //LOG_TRACE(logger, "{}.{} popChunk. isSinkFinished()={}, chunks.size()={}", session_id, table_id, is_sink_finished, chunks.size()); - Chunk res; - if (!chunks.empty()) [[likely]] - { - res.swap(chunks.front()); - chunks.pop_front(); - if (unlikely(!res.hasRows())) + if (!chunks.empty()) [[likely]] { - throw Exception(ErrorCodes::LOGICAL_ERROR, "Chunk should not be empty. table({}.{})", session_id, table_id); + res.swap(chunks.front()); + remained_rows -= res.getNumRows(); + chunks.pop_front(); + if (unlikely(!res.hasRows())) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Chunk should not be empty. table({}.{})", session_id, table_id); + } } } + wait_consume_data.notify_all(); return res; } diff --git a/src/Storages/DistributedShuffle/ShuffleBlockTable.h b/src/Storages/DistributedShuffle/ShuffleBlockTable.h index 608fdc1eda55..877d6b7a2096 100644 --- a/src/Storages/DistributedShuffle/ShuffleBlockTable.h +++ b/src/Storages/DistributedShuffle/ShuffleBlockTable.h @@ -18,19 +18,6 @@ namespace DB { /// -/// How to clear all the data when a query session has finished ? -/// The following measures were taken at current -/// 1)Chunks in ShuffleBlockTable are read only once, so we use popChunkWithoutMutex() for loading a chunk. -/// That ensures that all chunks are released after the loading finish. -/// 2) When ShuffleBlockTable becomes empty, it will call ShuffleBlockSession::releaseTable() to -/// release it-self. -/// 3) When ShuffleBlockSession becomes empty, it will call ShuffleBlockTableManager::tryCloseSession() to -/// release it-self. -/// All above will ensure all datas are released in normal processing. But more need be considered, exceptions could -/// happen during the processing which make the release actions not be called. Some measures may be token. -/// 1) In TCPHandler, catch all exceptions , and make a session releasing action on all nodes -/// 2) All sessions have a max TTL, make background routine to check timeout sessions and clear them. -/// class ShuffleBlockTable { @@ -47,7 +34,7 @@ class ShuffleBlockTable ~ShuffleBlockTable() { - LOG_TRACE(logger, "close table {}.{}", session_id, table_id); + LOG_TRACE(logger, "close table {}.{}. rows:{}. max_rows:{}", session_id, table_id, rows, max_rows); } inline const Block & getHeader() const @@ -79,8 +66,12 @@ class ShuffleBlockTable std::atomic is_sink_finished = false; std::list chunks; std::condition_variable wait_more_data; + std::condition_variable wait_consume_data; Poco::Logger * logger = &Poco::Logger::get("ShuffleBlockTable"); size_t rows = 0; + size_t remained_rows = 0; + size_t max_rows = 0; + const static size_t max_rows_limit = 20000000; }; using ShuffleBlockTablePtr = std::shared_ptr; diff --git a/src/Storages/DistributedShuffle/StorageShuffle.cpp b/src/Storages/DistributedShuffle/StorageShuffle.cpp index 867d71075a2b..ca42ac04954e 100644 --- a/src/Storages/DistributedShuffle/StorageShuffle.cpp +++ b/src/Storages/DistributedShuffle/StorageShuffle.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include namespace DB { @@ -221,6 +223,7 @@ class StorageShuffleSink : public SinkToStorage }; using InternalInserterPtr = std::shared_ptr; std::vector inserters; + const static size_t inserter_group_size = 4; std::vector> node_connections; std::shared_ptr hash_expr_cols_actions; String hash_expr_column_name; @@ -303,6 +306,7 @@ class StorageShuffleSink : public SinkToStorage auto cluster_addresses = getSortedShardAddresses(); for (const auto & node : cluster_addresses) { + #if 0 auto connection = std::make_shared( node.host_name, node.port, @@ -328,6 +332,36 @@ class StorageShuffleSink : public SinkToStorage internal_inserter->inserter = inserter; internal_inserter->max_rows_limit = context->getSettingsRef().max_block_size; inserters.emplace_back(internal_inserter); + #else + for (size_t i = 0; i < inserter_group_size; ++i) + { + auto connection = std::make_shared( + node.host_name, + node.port, + context->getGlobalContext()->getCurrentDatabase(), + node.user, + node.password, + node.cluster, + node.cluster_secret, + "StorageShuffleSink", + node.compression, + node.secure); + node_connections.emplace_back(connection); + auto inserter = std::make_shared( + *connection, + ConnectionTimeouts{ + settings.connect_timeout.value.seconds() * 1000, + settings.send_timeout.value.seconds() * 1000, + settings.receive_timeout.value.seconds() * 1000}, + insert_sql, + context->getSettings(), + context->getClientInfo()); + auto internal_inserter = std::make_shared(); + internal_inserter->inserter = inserter; + internal_inserter->max_rows_limit = context->getSettingsRef().max_block_size; + inserters.emplace_back(internal_inserter); + } + #endif } } @@ -358,26 +392,34 @@ class StorageShuffleSink : public SinkToStorage return addresses; } + static IColumn::Selector hashToSelector(const WeakHash32 & hash, size_t num_shards) + { + const auto & data = hash.getData(); + size_t num_rows = data.size(); + + IColumn::Selector selector(num_rows); + for (size_t i = 0; i < num_rows; ++i) + selector[i] = data[i] % num_shards; + return selector; + } + void splitBlock(Block & original_block, std::vector & split_blocks) { size_t num_rows = original_block.rows(); - size_t num_shards = inserters.size(); + size_t num_shards = inserters.size() / inserter_group_size; Block header = original_block.cloneEmpty(); ColumnRawPtrs hash_cols; for (const auto & hash_col_name : hash_expr_columns_names) { hash_cols.push_back(original_block.getByName(hash_col_name).column.get()); } - IColumn::Selector selector(num_rows); - for (size_t i = 0; i < num_rows; ++i) + + WeakHash32 hash(num_rows); + for (const auto & hash_col : hash_cols) { - SipHash hash; - for (const auto & hash_col : hash_cols) - { - hash_col->updateHashWithValue(i, hash); - } - selector[i] = hash.get64() % num_shards; + hash_col->updateWeakHash32(hash); } + auto selector = hashToSelector(hash, num_shards); for (size_t i = 0; i < num_shards; ++i) { @@ -394,33 +436,29 @@ class StorageShuffleSink : public SinkToStorage } } } + void sendBlocks(std::vector & blocks) { - std::list to_send_blocks; - for (size_t i = 0, sz = blocks.size(); i < sz; ++i) - to_send_blocks.emplace_back(i); - while (!to_send_blocks.empty()) + #if 0 + for (size_t i = 0; i < blocks.size(); ++i) { - for (auto iter = to_send_blocks.begin(); iter != to_send_blocks.end();) - { - auto & inserter = inserters[*iter]; - if (inserter->mutex.try_lock()) - { - auto & block = blocks[*iter]; - if (block.rows()) - { - inserter->tryWrite(block); - //inserter->inserter->write(block); - } - inserter->mutex.unlock(); - iter = to_send_blocks.erase(iter); - } - else - { - iter++; - } - } + auto & block = blocks[i]; + auto & inserter = inserters[i]; + std::lock_guard lock(inserter->mutex); + inserter->tryWrite(block); + } + #else + static std::atomic inserter_round_idx = 0; + inserter_round_idx += 1; + size_t idx = inserter_round_idx % inserter_group_size; + for (size_t i = 0; i < blocks.size(); ++i) + { + auto & block = blocks[i]; + auto inserter = inserters[i * inserter_group_size + idx]; + std::lock_guard lock(inserter->mutex); + inserter->tryWrite(block); } + #endif } }; @@ -522,7 +560,7 @@ Pipe StorageShuffleBase::read( auto remote_query_executor = std::make_shared( connection, remote_query, header, context_, nullptr, scalars, Tables(), processed_stage_, RemoteQueryExecutor::Extension{}); - //LOG_TRACE9logger, "run query on node:{}. query:{}", node.host_name, remote_query); + LOG_TRACE(logger, "run query on node:{}. query:{}", node.host_name, remote_query); pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); } } @@ -629,6 +667,7 @@ Pipe StorageLocalShuffle::read( { auto header = getInMemoryMetadata().getSampleBlock(); auto query_kind = context_->getClientInfo().query_kind; + LOG_TRACE(logger, "query kind={}, query={}", query_kind, queryToString(query_info_.query)); if (query_kind != ClientInfo::QueryKind::INITIAL_QUERY) { return Pipe(std::make_shared(context_, session_id, table_id, header)); @@ -656,7 +695,7 @@ Pipe StorageLocalShuffle::read( auto remote_query_executor = std::make_shared( connection, remote_query, header, context_, nullptr, scalars, Tables(), processed_stage_, RemoteQueryExecutor::Extension{}); - //LOG_TRACE9logger, "run query on node:{}. query:{}", node.host_name, remote_query); + LOG_TRACE(logger, "run query on node:{}. query:{}", node.host_name, remote_query); pipes.emplace_back(std::make_shared(remote_query_executor, false, false)); } } @@ -677,7 +716,7 @@ QueryProcessingStage::Enum StorageLocalShuffle::getQueryProcessingStage( const StorageSnapshotPtr & /*metadata_snapshot*/, SelectQueryInfo & query_info) const { - //LOG_TRACE9logger, "query:{}, to_stage:{}, query_kind:{}", queryToString(query_info.query), to_stage, local_context->getClientInfo().query_kind); + LOG_TRACE(logger, "query:{}, to_stage:{}, query_kind:{}", queryToString(query_info.query), to_stage, local_context->getClientInfo().query_kind); if (local_context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) { // When there is join in the query, cannot enable the two phases processing. It will cause From 918aa32b4b742782c22e640dcef7feeb19c2173d Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 23 Jun 2022 11:10:30 +0800 Subject: [PATCH 4/4] adjust shuffle join --- programs/server/Server.cpp | 1 - src/Core/Settings.h | 1 + .../StageQueryDistributedJoinRewriteAction.cpp | 5 ++++- src/Interpreters/InterpreterStageQuery.cpp | 1 - src/Interpreters/StorageDistributedTasksBuilder.cpp | 9 ++++++++- src/Interpreters/StorageDistributedTasksBuilder.h | 2 +- src/Interpreters/executeQuery.cpp | 3 ++- src/Storages/DistributedShuffle/StorageShuffle.cpp | 6 +++--- src/Storages/Hive/StorageHiveCluster.h | 7 +++++++ 9 files changed, 26 insertions(+), 9 deletions(-) diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index 7326e83a20e3..4f824cd1e869 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -689,7 +689,6 @@ int Server::main(const std::vector & /*args*/) } } - registerAllStorageDistributedTaskBuilderMakers(); Poco::ThreadPool server_pool(3, config().getUInt("max_connections", 1024)); std::mutex servers_lock; std::vector servers; diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 8d441e8d8b9f..527daf93b331 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -600,6 +600,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value) M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \ M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \ M(String, use_cluster_for_distributed_shuffle, "", "If you want to run the join and group by in distributed shuffle mode, set it as one of the available cluster.", 0) \ + M(Bool, enable_distribute_shuffle, false, "Enable shuffle join", 0) \ M(UInt64, shuffle_storage_session_timeout, 1800, "How long a session can be alive before expired by timeout", 0) \ M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \ M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \ diff --git a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp index 2e6e00e72c12..0bf6e703648e 100644 --- a/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp +++ b/src/Interpreters/ASTRewriters/StageQueryDistributedJoinRewriteAction.cpp @@ -354,7 +354,10 @@ bool StageQueryDistributedJoinRewriteAnalyzer::isApplicableJoinType() { const auto * join_tables = from_query->join(); auto * table_join = join_tables->table_join->as(); - if (table_join->kind == ASTTableJoin::Kind::Cross) + + if (table_join->kind != ASTTableJoin::Kind::Left && table_join->kind != ASTTableJoin::Kind::Inner) + return false; + if (table_join->strictness == ASTTableJoin::Strictness::Asof) return false; // TODO if right table is dict or special storage, return false; diff --git a/src/Interpreters/InterpreterStageQuery.cpp b/src/Interpreters/InterpreterStageQuery.cpp index 162d13204ae4..6ec8bec705be 100644 --- a/src/Interpreters/InterpreterStageQuery.cpp +++ b/src/Interpreters/InterpreterStageQuery.cpp @@ -104,7 +104,6 @@ BlockIO InterpreterStageQuery::execute(const QueryBlockIO & output_io, const Que auto pipeline_builder = query_plan.buildQueryPipeline( QueryPlanOptimizationSettings::fromContext(context), BuildQueryPipelineSettings::fromContext(context)); - pipeline_builder->addInterpreterContext(context); BlockIO res; res.pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); return res; diff --git a/src/Interpreters/StorageDistributedTasksBuilder.cpp b/src/Interpreters/StorageDistributedTasksBuilder.cpp index 75f497a59737..20433212acc4 100644 --- a/src/Interpreters/StorageDistributedTasksBuilder.cpp +++ b/src/Interpreters/StorageDistributedTasksBuilder.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -7,9 +8,13 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; } + +static std::once_flag init_builder_flag; +void registerAllStorageDistributedTaskBuilderMakers(); StorageDistributedTaskBuilderFactory & StorageDistributedTaskBuilderFactory::getInstance() { static StorageDistributedTaskBuilderFactory instance; + std::call_once(init_builder_flag, [](){ registerAllStorageDistributedTaskBuilderMakers(instance); }); return instance; } @@ -31,7 +36,9 @@ StorageDistributedTaskBuilderPtr StorageDistributedTaskBuilderFactory::getBuilde return iter->second(); } -void registerAllStorageDistributedTaskBuilderMakers() +void registerHiveClusterTasksBuilder(StorageDistributedTaskBuilderFactory & instance); +void registerAllStorageDistributedTaskBuilderMakers(StorageDistributedTaskBuilderFactory & instance) { + registerHiveClusterTasksBuilder(instance); } } diff --git a/src/Interpreters/StorageDistributedTasksBuilder.h b/src/Interpreters/StorageDistributedTasksBuilder.h index 1848933d139d..b20650ac4c6b 100644 --- a/src/Interpreters/StorageDistributedTasksBuilder.h +++ b/src/Interpreters/StorageDistributedTasksBuilder.h @@ -34,5 +34,5 @@ class StorageDistributedTaskBuilderFactory : boost::noncopyable }; -void registerAllStorageDistributedTaskBuilderMakers(); +void registerAllStorageDistributedTaskBuilderMakers(StorageDistributedTaskBuilderFactory & instance); } diff --git a/src/Interpreters/executeQuery.cpp b/src/Interpreters/executeQuery.cpp index d8052837dbb1..fcd689f75ecd 100644 --- a/src/Interpreters/executeQuery.cpp +++ b/src/Interpreters/executeQuery.cpp @@ -641,7 +641,7 @@ static std::tuple executeQueryImpl( select_with_union->set_of_modes.size(), select_with_union->list_of_selects->getID()); } - if (!context->getSettings().use_cluster_for_distributed_shuffle.value.empty()) + if (!context->getSettingsRef().use_cluster_for_distributed_shuffle.value.empty() && context->getSettingsRef().enable_distribute_shuffle) { MakeFunctionColumnAliasAction function_alias_action; ASTDepthFirstVisitor function_alias_visitor(function_alias_action, ast); @@ -666,6 +666,7 @@ static std::tuple executeQueryImpl( ast = add_finish_event_result; } } + interpreter = InterpreterFactory::get(ast, context, SelectQueryOptions(stage).setInternal(internal)); if (context->getCurrentTransaction() && !interpreter->supportsTransactions() && diff --git a/src/Storages/DistributedShuffle/StorageShuffle.cpp b/src/Storages/DistributedShuffle/StorageShuffle.cpp index ca42ac04954e..8a27f0e3dba0 100644 --- a/src/Storages/DistributedShuffle/StorageShuffle.cpp +++ b/src/Storages/DistributedShuffle/StorageShuffle.cpp @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -45,11 +45,11 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; } -class StorageShuffleSource : public SourceWithProgress, WithContext +class StorageShuffleSource : public ISource, WithContext { public: StorageShuffleSource(ContextPtr context_, const String & session_id_, const String & table_id_, const Block & header_) - : SourceWithProgress(header_), WithContext(context_), session_id(session_id_), table_id(table_id_), header(header_) + : ISource(header_), WithContext(context_), session_id(session_id_), table_id(table_id_), header(header_) { } diff --git a/src/Storages/Hive/StorageHiveCluster.h b/src/Storages/Hive/StorageHiveCluster.h index 62c6db815d2b..35df03445cdb 100644 --- a/src/Storages/Hive/StorageHiveCluster.h +++ b/src/Storages/Hive/StorageHiveCluster.h @@ -1,4 +1,5 @@ #pragma once +#include #include #if USE_HIVE #include @@ -66,6 +67,12 @@ class StorageHiveCluster : public IStorage, WithContext void checkAlterIsPossible(const AlterCommands & commands, ContextPtr local_context) const override; void alter(const AlterCommands & params, ContextPtr local_context, AlterLockHolder & alter_lock_holder) override; + std::shared_ptr getStorageHiveSettings() { return storage_settings; } + const String & getHiveMetastoreURL() const { return hive_metastore_url; } + const String & getHiveDatabase() const { return hive_database; } + const String & getHiveTableName() const { return hive_table; } + ASTPtr getPartitionByAst() const { return partition_by_ast; } + private: String cluster_name; String hive_metastore_url;