From 4a4d4944d60eca5491f947ee60606e737cefef54 Mon Sep 17 00:00:00 2001 From: ltrk2 <107155950+ltrk2@users.noreply.github.com> Date: Mon, 12 Sep 2022 16:12:34 -0700 Subject: [PATCH] Implement raw argument extraction to support nesting --- .../KustoFunctions/IParserKQLFunction.cpp | 73 +++++++++++++++++-- .../Kusto/KustoFunctions/IParserKQLFunction.h | 13 +++- .../KustoFunctions/KQLStringFunctions.cpp | 4 +- .../02366_kql_func_string.reference | 9 +++ .../0_stateless/02366_kql_func_string.sql | 9 +++ 5 files changed, 95 insertions(+), 13 deletions(-) diff --git a/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.cpp b/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.cpp index 68f9af5c4428..43ce504742f6 100644 --- a/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.cpp +++ b/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.cpp @@ -22,14 +22,42 @@ #include -namespace DB -{ -namespace ErrorCodes +namespace DB::ErrorCodes { - extern const int SYNTAX_ERROR; + extern const int NOT_IMPLEMENTED; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int SYNTAX_ERROR; +} + +namespace +{ +constexpr DB::TokenType determineClosingPair(const DB::TokenType token_type) +{ + if (token_type == DB::TokenType::OpeningCurlyBrace) + return DB::TokenType::ClosingCurlyBrace; + else if (token_type == DB::TokenType::OpeningRoundBracket) + return DB::TokenType::ClosingRoundBracket; + else if (token_type == DB::TokenType::OpeningSquareBracket) + return DB::TokenType::ClosingSquareBracket; + + throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Unhandled token: {}", magic_enum::enum_name(token_type)); +} + +constexpr bool isClosingBracket(const DB::TokenType token_type) +{ + return token_type == DB::TokenType::ClosingCurlyBrace || token_type == DB::TokenType::ClosingRoundBracket + || token_type == DB::TokenType::ClosingSquareBracket; +} + +constexpr bool isOpeningBracket(const DB::TokenType token_type) +{ + return token_type == DB::TokenType::OpeningCurlyBrace || token_type == DB::TokenType::OpeningRoundBracket + || token_type == DB::TokenType::OpeningSquareBracket; +} } +namespace DB +{ bool IParserKQLFunction::convert(String & out, IParser::Pos & pos) { return wrapConvertImpl( @@ -88,9 +116,9 @@ String IParserKQLFunction::generateUniqueIdentifier() return std::to_string(unique_random_generator()); } -String IParserKQLFunction::getArgument(const String & function_name, DB::IParser::Pos & pos) +String IParserKQLFunction::getArgument(const String & function_name, DB::IParser::Pos & pos, const ArgumentState argument_state) { - if (auto optionalArgument = getOptionalArgument(function_name, pos)) + if (auto optionalArgument = getOptionalArgument(function_name, pos, argument_state)) return std::move(*optionalArgument); throw Exception(std::format("Required argument was not provided in {}", function_name), ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -142,13 +170,42 @@ String IParserKQLFunction::getConvertedArgument(const String & fn_name, IParser: return converted_arg; } -std::optional IParserKQLFunction::getOptionalArgument(const String & function_name, DB::IParser::Pos & pos) +std::optional +IParserKQLFunction::getOptionalArgument(const String & function_name, DB::IParser::Pos & pos, const ArgumentState argument_state) { if (const auto & type = pos->type; type != DB::TokenType::Comma && type != DB::TokenType::OpeningRoundBracket) return {}; ++pos; - return getConvertedArgument(function_name, pos); + if (argument_state == ArgumentState::Parsed) + return getConvertedArgument(function_name, pos); + + if (argument_state != ArgumentState::Raw) + throw Exception( + ErrorCodes::NOT_IMPLEMENTED, + "Argument extraction is not implemented for {}::{}", + magic_enum::enum_type_name(), + magic_enum::enum_name(argument_state)); + + String expression; + std::vector scopes; + while (!pos->isEnd() && (!scopes.empty() || (pos->type != DB::TokenType::Comma && pos->type != DB::TokenType::ClosingRoundBracket))) + { + if (const auto token_type = pos->type; isOpeningBracket(token_type)) + scopes.push_back(token_type); + else if (isClosingBracket(token_type)) + { + if (scopes.empty() || determineClosingPair(scopes.back()) != token_type) + throw Exception(DB::ErrorCodes::SYNTAX_ERROR, "Unmatched token: {} when parsing {}", magic_enum::enum_name(token_type), function_name); + + scopes.pop_back(); + } + + expression.append(pos->begin, pos->end); + ++pos; + } + + return expression; } String IParserKQLFunction::getKQLFunctionName(IParser::Pos & pos) diff --git a/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.h b/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.h index 245b196c8e37..45a65bd1d587 100644 --- a/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.h +++ b/src/Parsers/Kusto/KustoFunctions/IParserKQLFunction.h @@ -43,14 +43,21 @@ class IParserKQLFunction static String getExpression(IParser::Pos & pos); protected: + enum class ArgumentState + { + Parsed, + Raw + }; + virtual bool convertImpl(String & out, IParser::Pos & pos) = 0; static bool directMapping(String & out, IParser::Pos & pos, const String & ch_fn); static String generateUniqueIdentifier(); - static String getArgument(const String & function_name, DB::IParser::Pos & pos); + static String getArgument(const String & function_name, DB::IParser::Pos & pos, ArgumentState argument_state = ArgumentState::Parsed); static String getConvertedArgument(const String & fn_name, IParser::Pos & pos); - static std::optional getOptionalArgument(const String & function_name, DB::IParser::Pos & pos); - static String kqlCallToExpression(std::string_view function_name, std::initializer_list params, uint32_t max_depth); + static std::optional getOptionalArgument(const String & function_name, DB::IParser::Pos & pos, ArgumentState argument_state = ArgumentState::Parsed); + static String + kqlCallToExpression(std::string_view function_name, std::initializer_list params, uint32_t max_depth); static String kqlCallToExpression(std::string_view function_name, std::span params, uint32_t max_depth); static void validateEndOfFunction(const String & fn_name, IParser::Pos & pos); static String getKQLFunctionName(IParser::Pos & pos); diff --git a/src/Parsers/Kusto/KustoFunctions/KQLStringFunctions.cpp b/src/Parsers/Kusto/KustoFunctions/KQLStringFunctions.cpp index f37563ecccd3..631ea69a1297 100644 --- a/src/Parsers/Kusto/KustoFunctions/KQLStringFunctions.cpp +++ b/src/Parsers/Kusto/KustoFunctions/KQLStringFunctions.cpp @@ -553,8 +553,8 @@ bool Trim::convertImpl(String & out, IParser::Pos & pos) if (fn_name.empty()) return false; - const auto regex = getArgument(fn_name, pos); - const auto source = getArgument(fn_name, pos); + const auto regex = getArgument(fn_name, pos, ArgumentState::Raw); + const auto source = getArgument(fn_name, pos, ArgumentState::Raw); out = kqlCallToExpression("trim_start", {regex, std::format("trim_end({0}, {1})", regex, source)}, pos.max_depth); return true; diff --git a/tests/queries/0_stateless/02366_kql_func_string.reference b/tests/queries/0_stateless/02366_kql_func_string.reference index 25da15bc25db..216be22b05d1 100644 --- a/tests/queries/0_stateless/02366_kql_func_string.reference +++ b/tests/queries/0_stateless/02366_kql_func_string.reference @@ -277,12 +277,21 @@ kusto xxx -- trim() https://www.ibm.com Te st1 + asd +asd +sd -- trim_start() www.ibm.com Te st1// $ +asdw + +asd -- trim_end() https - Te st1 +wasd + +asd -- replace_regex Number was: 1 -- has_any_index() diff --git a/tests/queries/0_stateless/02366_kql_func_string.sql b/tests/queries/0_stateless/02366_kql_func_string.sql index d367ec553c4a..19c228e0708c 100644 --- a/tests/queries/0_stateless/02366_kql_func_string.sql +++ b/tests/queries/0_stateless/02366_kql_func_string.sql @@ -199,12 +199,21 @@ print translate('krasp', 'otsku', 'spark'), translate('abc', '', 'ab'), translat print '-- trim()'; print trim("--", "--https://www.ibm.com--"); print trim("[^\w]+", strcat("- ","Te st", "1", "// $")); +print trim("", " asd "); +print trim("a$", "asd"); +print trim("^a", "asd"); print '-- trim_start()'; print trim_start("https://", "https://www.ibm.com"); print trim_start("[^\w]+", strcat("- ","Te st", "1", "// $")); +print trim_start("asd$", "asdw"); +print trim_start("asd$", "asd"); +print trim_start("d$", "asd"); print '-- trim_end()'; print trim_end("://www.ibm.com", "https://www.ibm.com"); print trim_end("[^\w]+", strcat("- ","Te st", "1", "// $")); +print trim_end("^asd", "wasd"); +print trim_end("^asd", "asd"); +print trim_end("^a", "asd"); print '-- replace_regex'; print replace_regex(strcat('Number is ', '1'), 'is (\d+)', 'was: \1'); print '-- has_any_index()';