From bd436659409fc8c3031a4c25ce2d170655749d40 Mon Sep 17 00:00:00 2001 From: Zhiyuan Liang <132966438+Ami11111@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:09:55 +0800 Subject: [PATCH] Support Substring function (#2061) ### What problem does this PR solve? Support Substring function ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases --- example/http/functions.sh | 15 ++++ src/function/scalar/regex.cpp | 4 +- src/function/scalar/substring.cpp | 81 ++++++++----------- src/function/scalar_function.cppm | 56 ++++++++++++- .../operator/ternary_operator.cppm | 15 ++++ test/sql/dql/type/varchar.slt | 18 +++++ 6 files changed, 137 insertions(+), 52 deletions(-) diff --git a/example/http/functions.sh b/example/http/functions.sh index e466321619..cce7210da6 100644 --- a/example/http/functions.sh +++ b/example/http/functions.sh @@ -203,6 +203,21 @@ curl --request GET \ "filter": "regex(body, '\''('[0-9A-Za-z_]+'('[-+.][0-9A-Za-z_]+')''*'')'@'('[0-9A-Za-z_]+'('[-.][0-9A-Za-z_]+')''*'')''\\'.'('[0-9A-Za-z_]+'('[-.][0-9A-Za-z_]+')''*'')'\'')" } ' +# show rows of 'tbl1' where first 4 chars of body is 'test' +echo -e '\n\n-- show rows of 'tbl1' where first 4 chars of body is 'test'' +curl --request GET \ + --url http://localhost:23820/databases/default_db/tables/tbl1/docs \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data ' + { + "output": + [ + "body" + ], + "filter": "substring(body, 0, 4) = '\'test\''" + } ' + # drop tbl1 echo -e '\n\n-- drop tbl1' curl --request DELETE \ diff --git a/src/function/scalar/regex.cpp b/src/function/scalar/regex.cpp index 82edaa4f60..3ca6f7998a 100644 --- a/src/function/scalar/regex.cpp +++ b/src/function/scalar/regex.cpp @@ -42,11 +42,11 @@ void RegisterRegexFunction(const UniquePtr &catalog_ptr){ SharedPtr function_set_ptr = MakeShared(func_name); - ScalarFunction Regex_function(func_name, + ScalarFunction regex_function(func_name, {DataType(LogicalType::kVarchar), DataType(LogicalType::kVarchar)}, DataType(LogicalType::kBoolean), &ScalarFunction::BinaryFunction); - function_set_ptr->AddFunction(Regex_function); + function_set_ptr->AddFunction(regex_function); Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); } diff --git a/src/function/scalar/substring.cpp b/src/function/scalar/substring.cpp index 76262bfd7e..6661a2b0f1 100644 --- a/src/function/scalar/substring.cpp +++ b/src/function/scalar/substring.cpp @@ -33,59 +33,48 @@ namespace infinity { struct SubstrFunction { template - static inline bool Run(TA, TB, TC, TD &, ColumnVector *) { + static inline bool Run(TA &first, TB &second, TC &third, TD &result, ColumnVector *first_ptr, ColumnVector *result_ptr) { String error_message = "Not implement: SubstrFunction::Run"; UnrecoverableError(error_message); } }; template <> -inline bool SubstrFunction::Run(VarcharT, BigIntT, BigIntT, VarcharT &, ColumnVector *) { - // Validate the input before slice the string - String error_message = "Not implement: SubstrFunction::Run"; - UnrecoverableError(error_message); +inline bool SubstrFunction::Run(VarcharT &first, BigIntT &second, BigIntT &third, VarcharT &result, ColumnVector *first_ptr, ColumnVector * result_ptr) { + if (second < 0) { + UnrecoverableError(fmt::format("substring start offset should >= 0, currently it is {}", second)); + } + if (third < 0) { + UnrecoverableError(fmt::format("substring length should >= 0, currently it is {}", second)); + } -// if (second < 0) { -// Error(fmt::format("substring start offset should >= 0, currently it is {}", second)); -// } -// -// if (third < 0) { -// Error(fmt::format("substring length should >= 0, currently it is {}", second)); -// } -// -// if (third == 0) { -// // Construct empty varchar value; -// result.InitializeAsEmptyStr(); -// return true; -// } -// -// SizeT source_len = first.GetDataLen(); -// if (second >= source_len) { -// // Construct empty varchar value; -// result.InitializeAsEmptyStr(); -// return true; -// } -// -// SizeT start_offset = second; -// SizeT end_offset = 0; -// if (start_offset + third >= source_len) { -// end_offset = source_len; -// } else { -// end_offset = start_offset + third; -// } -// -// SizeT copied_length = end_offset - start_offset; -// ptr_t source_ptr = first.GetDataPtr(); -// if (copied_length <= VarcharT::INLINE_LENGTH) { -// // inline varchar -// std::memcpy(result.prefix, source_ptr + start_offset, copied_length); -// result.length = copied_length; -// } else { -// std::memcpy(result.prefix, source_ptr + start_offset, VarcharT::INLINE_LENGTH); -// result.ptr = column_vector_ptr->buffer_->fix_heap_mgr_->Allocate(copied_length); -// std::memcpy(result.ptr, source_ptr + start_offset, copied_length); -// } + Span first_v = first_ptr->GetVarcharInner(first); + if (third == 0) { + // Construct empty varchar value; + Span substr_span = Span(first_v.data(), 0); + result_ptr->AppendVarcharInner(substr_span, result); + return true; + } + + SizeT source_len = first_v.size(); + if ((SizeT)second >= source_len) { + // Construct empty varchar value; + Span substr_span = Span(first_v.data(), 0);; + result_ptr->AppendVarcharInner(substr_span, result); + return true; + } + + SizeT start_offset = second; + SizeT end_offset = 0; + if (start_offset + third >= source_len) { + end_offset = source_len; + } else { + end_offset = start_offset + third; + } + + Span substr_span = Span(first_v.data() + start_offset, end_offset - start_offset); + result_ptr->AppendVarcharInner(substr_span, result); return true; } @@ -98,7 +87,7 @@ void RegisterSubstringFunction(const UniquePtr &catalog_ptr) { ScalarFunction varchar_substr(func_name, {DataType(LogicalType::kVarchar), DataType(LogicalType::kBigInt), DataType(LogicalType::kBigInt)}, {DataType(LogicalType::kVarchar)}, - &ScalarFunction::TernaryFunctionToVarlenWithFailure); + &ScalarFunction::TernaryFunctionVarlenToVarlenWithFailure); function_set_ptr->AddFunction(varchar_substr); Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); diff --git a/src/function/scalar_function.cppm b/src/function/scalar_function.cppm index f93ac381fb..fa7e4e7eb4 100644 --- a/src/function/scalar_function.cppm +++ b/src/function/scalar_function.cppm @@ -60,7 +60,7 @@ struct BinaryOpDirectWrapper { template struct TernaryOpDirectWrapper { template - inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *, SizeT, void *) { + inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *, SizeT, void *, void *) { return Operator::template Run(first, second, third, result); } }; @@ -94,7 +94,7 @@ struct BinaryTryOpWrapper { template struct TernaryTryOpWrapper { template - inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *nulls_ptr, SizeT idx, void *) { + inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *nulls_ptr, SizeT idx, void *, void *) { if (Operator::template Run(first, second, third, result)) { return; } @@ -125,7 +125,7 @@ struct BinaryOpDirectToVarlenWrapper { template struct TernaryOpDirectToVarlenWrapper { template - inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *, SizeT, void *state_ptr) { + inline static void Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *, SizeT, void *, void *state_ptr) { auto *function_data_ptr = (ScalarFunctionData *)(state_ptr); return Operator::template Run(first, second, @@ -167,7 +167,7 @@ template struct TernaryTryOpToVarlenWrapper { template inline static void - Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *nulls_ptr, SizeT idx, void *state_ptr) { + Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *nulls_ptr, SizeT idx, void *, void *state_ptr) { auto *function_data_ptr = (ScalarFunctionData *)(state_ptr); if (Operator::template Run(first, second, @@ -192,6 +192,27 @@ struct UnaryOpDirectVarlenToVarlenWrapper { } }; +template +struct TernaryTryOpVarlenToVarlenWrapper { + template + inline static void + Execute(FirstType first, SecondType second, ThirdType third, ResultType &result, Bitmask *nulls_ptr, SizeT idx, void *first_ptr, void *state_ptr) { + auto *function_data_ptr_first = (ScalarFunctionData *)(first_ptr); + auto *function_data_ptr = (ScalarFunctionData *)(state_ptr); + if (Operator::template Run(first, + second, + third, + result, + function_data_ptr_first->column_vector_ptr_, + function_data_ptr->column_vector_ptr_)) { + return; + } + + nulls_ptr->SetFalse(idx); + result = NullValue(); + } +}; + using ScalarFunctionType = std::function &)>; @@ -409,6 +430,7 @@ public: output, input.row_count(), nullptr, + nullptr, true); } @@ -429,6 +451,7 @@ public: output, input.row_count(), nullptr, + nullptr, true); } @@ -449,6 +472,7 @@ public: input.column_vectors[2], output, input.row_count(), + nullptr, &function_data, true); } @@ -470,6 +494,30 @@ public: input.column_vectors[2], output, input.row_count(), + nullptr, + &function_data, + true); + } + + // Ternary function result is varlen with some failures such as overflow. + template + static inline void TernaryFunctionVarlenToVarlenWithFailure(const DataBlock &input, SharedPtr &output) { + if (input.column_count() != 3) { + String error_message = "Ternary function: input column count isn't three."; + UnrecoverableError(error_message); + } + if (!input.Finalized()) { + String error_message = "Input data block is finalized"; + UnrecoverableError(error_message); + } + ScalarFunctionData function_data_first(input.column_vectors[0].get()); + ScalarFunctionData function_data(output.get()); + TernaryOperator::Execute>(input.column_vectors[0], + input.column_vectors[1], + input.column_vectors[2], + output, + input.row_count(), + &function_data_first, &function_data, true); } diff --git a/src/storage/column_vector/operator/ternary_operator.cppm b/src/storage/column_vector/operator/ternary_operator.cppm index bd835eafc3..a07f096090 100644 --- a/src/storage/column_vector/operator/ternary_operator.cppm +++ b/src/storage/column_vector/operator/ternary_operator.cppm @@ -32,6 +32,7 @@ public: const SharedPtr &third, SharedPtr &result, SizeT count, + void *state_ptr_first, void *state_ptr, bool nullable) { @@ -70,6 +71,7 @@ public: result_ptr, result_null, count, + state_ptr_first, state_ptr); } else { ExecuteFFF(first_ptr, @@ -78,6 +80,7 @@ public: result_ptr, result_null, count, + state_ptr_first, state_ptr); } result->Finalize(count); @@ -110,6 +113,7 @@ public: result_ptr, result_null, count, + state_ptr_first, state_ptr); } else { ExecuteFCC(first_ptr, @@ -118,6 +122,7 @@ public: result_ptr, result_null, count, + state_ptr_first, state_ptr); } result->Finalize(count); @@ -159,6 +164,7 @@ public: result_ptr[0], result_null.get(), 0, + state_ptr_first, state_ptr); } else { result->nulls_ptr_->SetFalse(0); @@ -171,6 +177,7 @@ public: result_ptr[0], result_null.get(), 0, + state_ptr_first, state_ptr); } result->Finalize(1); @@ -187,6 +194,7 @@ private: ResultType *__restrict result_ptr, SharedPtr &result_null, SizeT count, + void *state_ptr_first, void *state_ptr) { for (SizeT i = 0; i < count; i++) { Operator::template Execute(first_ptr[i], @@ -195,6 +203,7 @@ private: result_ptr[i], result_null.get(), i, + state_ptr_first, state_ptr); } } @@ -209,6 +218,7 @@ private: ResultType *__restrict result_ptr, SharedPtr &result_null, SizeT count, + void *state_ptr_first, void *state_ptr) { *result_null = *first_null; result_null->MergeAnd(*second_null); @@ -223,6 +233,7 @@ private: result_ptr[i], result_null.get(), i, + state_ptr_first, state_ptr); return i + 1 < count; }); @@ -236,6 +247,7 @@ private: ResultType *__restrict result_ptr, SharedPtr &result_null, SizeT count, + void *state_ptr_first, void *state_ptr) { for (SizeT i = 0; i < count; i++) { Operator::template Execute(first_ptr[i], @@ -244,6 +256,7 @@ private: result_ptr[i], result_null.get(), i, + state_ptr_first, state_ptr); } } @@ -258,6 +271,7 @@ private: ResultType *__restrict result_ptr, SharedPtr &result_null, SizeT count, + void *state_ptr_first, void *state_ptr) { *result_null = *first_null; if (!(second_null->IsAllTrue() && third_null->IsAllTrue())) { @@ -273,6 +287,7 @@ private: result_ptr[i], result_null.get(), i, + state_ptr_first, state_ptr); return i + 1 < count; }); diff --git a/test/sql/dql/type/varchar.slt b/test/sql/dql/type/varchar.slt index faf2b6c99c..003bf65bb0 100644 --- a/test/sql/dql/type/varchar.slt +++ b/test/sql/dql/type/varchar.slt @@ -87,5 +87,23 @@ SELECT * FROM test_varchar_filter where regex(c1, '(\w+([-+.]\w+)*)@(\w+([-.]\w+ ---- regex@regex.com gmail@gmail.com 6 +query XIII +SELECT * FROM test_varchar_filter where substring(c1, 0, 4) = 'abcd'; +---- +abcddddd abcddddd 1 +abcddddc abcddddd 2 +abcdddde abcddddd 3 +abcdddde abcdddde 4 + +query XIIII +SELECT * FROM test_varchar_filter where substring(c1, 0, 0) = ''; +---- +abcddddd abcddddd 1 +abcddddc abcddddd 2 +abcdddde abcddddd 3 +abcdddde abcdddde 4 +abc abcd 5 +regex@regex.com gmail@gmail.com 6 + statement ok DROP TABLE test_varchar_filter;