From 4dbe29c23293f5cd6128a60aaaf07607ad05fa26 Mon Sep 17 00:00:00 2001 From: Pratik Joseph Dabre Date: Fri, 14 Jun 2024 10:00:01 -0700 Subject: [PATCH] Add normalize Presto scalar function (#8590) Summary: Add normalize() Presto scalar function Resolves : https://github.com/prestodb/presto/issues/20224 Reference : https://github.com/prestodb/presto/blob/master/presto-main/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java#L833 Pull Request resolved: https://github.com/facebookincubator/velox/pull/8590 Reviewed By: bikramSingh91 Differential Revision: D58384754 Pulled By: kevinwilfong fbshipit-source-id: 2ed7a0e7311c3f14bfdba2d3784aee50805fc6ce --- velox/docs/functions/presto/string.rst | 26 +++++++- velox/external/utf8proc/utf8procImpl.h | 2 +- velox/functions/prestosql/StringFunctions.h | 62 +++++++++++++++++++ .../StringFunctionsRegistration.cpp | 4 ++ .../prestosql/tests/StringFunctionsTest.cpp | 54 ++++++++++++++++ 5 files changed, 146 insertions(+), 2 deletions(-) diff --git a/velox/docs/functions/presto/string.rst b/velox/docs/functions/presto/string.rst index 44fb82cef4fa..22d88e833ca4 100644 --- a/velox/docs/functions/presto/string.rst +++ b/velox/docs/functions/presto/string.rst @@ -295,6 +295,30 @@ String Functions Unicode Functions ----------------- +.. function:: normalize(string) -> varchar + + Transforms ``string`` with NFC normalization form. + +.. function:: normalize(string, form) -> varchar + + Reference: https://unicode.org/reports/tr15/#Norm_Forms + Transforms ``string`` with the specified normalization form. + ``form`` must be be one of the following keywords: + + ======== =========== + Form Description + ======== =========== + ``NFD`` Canonical Decomposition + ``NFC`` Canonical Decomposition, followed by Canonical Composition + ``NFKD`` Compatibility Decomposition + ``NFKC`` Compatibility Decomposition, followed by Canonical Composition + ======== =========== + + .. note:: + + This SQL-standard function has special syntax and requires + specifying ``form`` as a keyword, not as a string. + .. function:: to_utf8(string) -> varbinary - Encodes ``string`` into a UTF-8 varbinary representation. + Encodes ``string`` into a UTF-8 varbinary representation. \ No newline at end of file diff --git a/velox/external/utf8proc/utf8procImpl.h b/velox/external/utf8proc/utf8procImpl.h index 32e730240ff1..673de8765c6a 100644 --- a/velox/external/utf8proc/utf8procImpl.h +++ b/velox/external/utf8proc/utf8procImpl.h @@ -635,7 +635,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( } decomp_result = utf8proc_decompose_char( uc, - buffer + wpos, + buffer ? buffer + wpos : buffer, (bufsize > wpos) ? (bufsize - wpos) : 0, options, &boundclass); diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index 7afaea4e4307..f94b2aeec5fa 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -494,4 +494,66 @@ struct LevenshteinDistanceFunction { } }; +template +struct NormalizeFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Map for holding normalization form options + const static inline std::unordered_map + normalizationOptions{ + {"NFC", (UTF8PROC_STABLE | UTF8PROC_COMPOSE)}, + {"NFD", (UTF8PROC_STABLE | UTF8PROC_DECOMPOSE)}, + {"NFKC", (UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT)}, + {"NFKD", (UTF8PROC_STABLE | UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT)}}; + + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& /*inputTypes*/, + const core::QueryConfig& /*config*/, + const arg_type* /*string*/, + const arg_type* form) { + VELOX_USER_CHECK_NOT_NULL(form); + VELOX_USER_CHECK_NE( + normalizationOptions.count(*form), + 0, + "Normalization form must be one of [NFD, NFC, NFKD, NFKC]"); + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& string) { + doCall(result, string, "NFC"); + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& string, + const arg_type& form) { + doCall(result, string, form); + } + + // Note: This function newly allocates output using malloc so it should be + // free'd at the end. + FOLLY_ALWAYS_INLINE void doCall( + out_type& result, + const arg_type& string, + const arg_type& form) { + utf8proc_uint8_t* output = nullptr; + auto outputLength = utf8proc_map( + (utf8proc_uint8_t*)string.data(), + string.size(), + &output, + normalizationOptions.at(form)); + if (outputLength < 0) { + result = string; + } else { + result.resize(outputLength); + if (result.data()) { + std::memcpy( + result.data(), reinterpret_cast(output), outputLength); + } + } + free(output); + } +}; + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp index c1f4d4e471bd..f3847fd7b37d 100644 --- a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp @@ -133,6 +133,10 @@ void registerStringFunctions(const std::string& prefix) { registerFunction( {prefix + "strrpos"}); + registerFunction({prefix + "normalize"}); + registerFunction( + {prefix + "normalize"}); + // word_stem function registerFunction({prefix + "word_stem"}); registerFunction( diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index 4c05965fb56a..751087e1d24a 100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -1947,3 +1947,57 @@ TEST_F(StringFunctionsTest, hammingDistance) { hammingDistance("\xFF\x82\xFF", "\xF0\x82"), "The input strings to hamming_distance function must have the same length"); } + +TEST_F(StringFunctionsTest, normalize) { + const auto normalizeWithoutForm = [&](std::optional string) { + return evaluateOnce("normalize(c0)", string); + }; + + const auto normalizeWithForm = [&](std::optional string, + const std::string& form) { + return evaluateOnce( + fmt::format("normalize(c0, '{}')", form), string); + }; + + EXPECT_EQ(normalizeWithoutForm(std::nullopt), std::nullopt); + EXPECT_EQ(normalizeWithoutForm(""), ""); + EXPECT_EQ(normalizeWithoutForm("sch\u00f6n"), "sch\u00f6n"); + EXPECT_EQ(normalizeWithForm(std::nullopt, "NFD"), std::nullopt); + EXPECT_EQ(normalizeWithForm("", "NFKC"), ""); + EXPECT_EQ( + normalizeWithForm( + (normalizeWithForm("sch\u00f6n", "NFD"), "scho\u0308n"), "NFC"), + "sch\u00f6n"); + EXPECT_EQ( + normalizeWithForm( + (normalizeWithForm("sch\u00f6n", "NFKD"), "scho\u0308n"), "NFKC"), + "sch\u00f6n"); + EXPECT_EQ( + normalizeWithForm("Hello world from Velox!!", "NFKC"), + "Hello world from Velox!!"); + + std::string testStringOne = + "\u3231\u3327\u3326\u2162\u3231\u3327\u3326\u2162\u3231\u3327\u3326\u2162"; + std::string testStringTwo = + "(\u682a)\u30c8\u30f3\u30c9\u30ebIII(\u682a)\u30c8\u30f3\u30c9\u30ebIII(\u682a)\u30c8\u30f3\u30c9\u30ebIII"; + EXPECT_EQ(normalizeWithForm(testStringOne, "NFKC"), testStringTwo); + EXPECT_EQ( + normalizeWithForm((normalizeWithForm(testStringTwo, "NFC")), "NFKC"), + testStringTwo); + + std::string testStringThree = + "\uff8a\uff9d\uff76\uff78\uff76\uff85\uff8a\uff9d\uff76\uff78\uff76\uff85\uff8a\uff9d\uff76\uff78\uff76\uff85\uff8a\uff9d\uff76\uff78\uff76\uff85"; + std::string testStringFour = + "\u30cf\u30f3\u30ab\u30af\u30ab\u30ca\u30cf\u30f3\u30ab\u30af\u30ab\u30ca\u30cf\u30f3\u30ab\u30af\u30ab\u30ca\u30cf\u30f3\u30ab\u30af\u30ab\u30ca"; + EXPECT_EQ(normalizeWithForm(testStringThree, "NFKC"), testStringFour); + EXPECT_EQ( + normalizeWithForm((normalizeWithForm(testStringFour, "NFD")), "NFKC"), + testStringFour); + + // Invalid UTF-8 string + std::string inValidTestString = "\xEF\xBE\x8"; + EXPECT_EQ(normalizeWithForm(inValidTestString, "NFKC"), inValidTestString); + VELOX_ASSERT_THROW( + normalizeWithForm("sch\u00f6n", "NFKE"), + "Normalization form must be one of [NFD, NFC, NFKD, NFKC]"); +}