From 9d6f6ea9ee71ba936db04c1f33e41d606222ab82 Mon Sep 17 00:00:00 2001 From: Jimmy Lu Date: Fri, 20 Dec 2024 17:44:25 -0800 Subject: [PATCH] fix: Optimize json_parse Differential Revision: D67538322 --- velox/common/base/SortingNetwork.h | 161 +++++++ velox/functions/lib/Utf8Utils.cpp | 27 +- velox/functions/lib/Utf8Utils.h | 8 +- velox/functions/lib/tests/Utf8Test.cpp | 11 - velox/functions/prestosql/JsonFunctions.cpp | 403 ++++++++++-------- .../prestosql/json/JsonStringUtil.cpp | 202 ++++++--- .../functions/prestosql/json/JsonStringUtil.h | 49 ++- .../prestosql/tests/JsonFunctionsTest.cpp | 4 +- velox/functions/prestosql/types/JsonType.cpp | 4 +- 9 files changed, 575 insertions(+), 294 deletions(-) create mode 100644 velox/common/base/SortingNetwork.h diff --git a/velox/common/base/SortingNetwork.h b/velox/common/base/SortingNetwork.h new file mode 100644 index 0000000000000..abf5415fef5b9 --- /dev/null +++ b/velox/common/base/SortingNetwork.h @@ -0,0 +1,161 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox { + +constexpr int kSortingNetworkMaxSize = 32; + +template +void sortingNetwork(T* data, int size, LessThan&& lt); + +namespace detail { + +// Compile time generated Bose-Nelson sorting network. +// +// https://bertdobbelaere.github.io/sorting_networks.html +// https://github.com/Vectorized/Static-Sort/blob/master/include/static_sort.h +template +class SortingNetworkImpl { + public: + template + static void apply(T* data, LessThan&& lt) { + PS ps(data, lt); + } + + private: + template + static void compareExchange(T* data, LessThan lt) { + // This is branchless if `lt' is branchless. + auto c = lt(data[I], data[J]); + auto min = c ? data[I] : data[J]; + data[J] = c ? data[J] : data[I]; + data[I] = min; + } + + template + struct PB { + PB(T* data, LessThan lt) { + enum { + L = X >> 1, + M = (X & 1 ? Y : Y + 1) >> 1, + IAddL = I + L, + XSubL = X - L, + }; + PB p0(data, lt); + PB p1(data, lt); + PB p2(data, lt); + } + }; + + template + struct PB { + PB(T* data, LessThan lt) { + compareExchange(data, lt); + } + }; + + template + struct PB { + PB(T* data, LessThan lt) { + compareExchange(data, lt); + compareExchange(data, lt); + } + }; + + template + struct PB { + PB(T* data, LessThan lt) { + compareExchange(data, lt); + compareExchange(data, lt); + } + }; + + template + struct PS { + PS(T* data, LessThan lt) { + enum { L = M >> 1, IAddL = I + L, MSubL = M - L }; + PS ps0(data, lt); + PS ps1(data, lt); + PB pb(data, lt); + } + }; + + template + struct PS { + PS(T* /*data*/, LessThan /*lt*/) {} + }; +}; + +} // namespace detail + +template +void sortingNetwork(T* data, int size, LessThan&& lt) { + switch (size) { + case 0: + case 1: + return; + +#ifdef VELOX_SORTING_NETWORK_IMPL_APPLY_CASE +#error "Macro name clash: VELOX_SORTING_NETWORK_IMPL_APPLY_CASE" +#endif +#define VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(_n) \ + case _n: \ + detail::SortingNetworkImpl<_n>::apply(data, std::forward(lt)); \ + return; + + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(2) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(3) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(4) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(5) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(6) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(7) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(8) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(9) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(10) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(11) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(12) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(13) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(14) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(15) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(16) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(17) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(18) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(19) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(20) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(21) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(22) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(23) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(24) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(25) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(26) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(27) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(28) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(29) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(30) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(31) + VELOX_SORTING_NETWORK_IMPL_APPLY_CASE(32) + +#undef VELOX_SORTING_NETWORK_IMPL_APPLY_CASE + + default: + VELOX_UNREACHABLE(); + } +} + +} // namespace facebook::velox diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index dded518c6b61b..b5bc77196fd74 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -173,26 +173,6 @@ tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint) { return -1; } -bool hasInvalidUTF8(const char* input, int32_t len) { - for (size_t inputIndex = 0; inputIndex < len;) { - if (IS_ASCII(input[inputIndex])) { - // Ascii - inputIndex++; - } else { - // Unicode - int32_t codePoint; - auto charLength = - tryGetUtf8CharLength(input + inputIndex, len - inputIndex, codePoint); - if (charLength < 0) { - return true; - } - inputIndex += charLength; - } - } - - return false; -} - size_t replaceInvalidUTF8Characters( char* outputBuffer, const char* input, @@ -213,12 +193,9 @@ size_t replaceInvalidUTF8Characters( outputIndex += charLength; inputIndex += charLength; } else { - size_t replaceCharactersToWriteOut = inputIndex < len - 1 && - isMultipleInvalidSequences(input, inputIndex) - ? -charLength - : 1; const auto& replacementCharacterString = - kReplacementCharacterStrings[replaceCharactersToWriteOut - 1]; + getInvalidUTF8ReplacementString( + input + inputIndex, len - inputIndex, -charLength); std::memcpy( outputBuffer + outputIndex, replacementCharacterString.data(), diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 781f3db6ccc76..3e44421942714 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -121,8 +121,12 @@ FOLLY_ALWAYS_INLINE bool isMultipleInvalidSequences( inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; } -/// Returns true only if invalid UTF-8 is present in the input string. -bool hasInvalidUTF8(const char* input, int32_t len); +inline const std::string_view& +getInvalidUTF8ReplacementString(const char* input, int len, int codePointSize) { + auto index = + len >= 2 && isMultipleInvalidSequences(input, 0) ? codePointSize - 1 : 0; + return kReplacementCharacterStrings[index]; +} /// Replaces invalid UTF-8 characters with replacement characters similar to /// that produced by Presto java. The function requires that output have diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 24c0ddbbdcec0..8abc074d73477 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -104,17 +104,6 @@ TEST(Utf8Test, tryCharLength) { ASSERT_EQ(-1, tryCharLength({0xBF})); } -TEST(UTF8Test, validUtf8) { - auto tryHasInvalidUTF8 = [](const std::vector& bytes) { - return hasInvalidUTF8( - reinterpret_cast(bytes.data()), bytes.size()); - }; - - ASSERT_FALSE(tryHasInvalidUTF8({0x5c, 0x19, 0x7A})); - ASSERT_TRUE(tryHasInvalidUTF8({0x5c, 0x19, 0x7A, 0xBF})); - ASSERT_TRUE(tryHasInvalidUTF8({0x64, 0x65, 0x1A, 0b11100000, 0x81, 0xBF})); -} - TEST(UTF8Test, replaceInvalidUTF8Characters) { auto testReplaceInvalidUTF8Chars = [](const std::string& input, const std::string& expected) { diff --git a/velox/functions/prestosql/JsonFunctions.cpp b/velox/functions/prestosql/JsonFunctions.cpp index 3df47a7dd39ed..8007149a1709e 100644 --- a/velox/functions/prestosql/JsonFunctions.cpp +++ b/velox/functions/prestosql/JsonFunctions.cpp @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/base/SortingNetwork.h" #include "velox/expression/VectorFunction.h" #include "velox/functions/lib/Utf8Utils.h" +#include "velox/functions/lib/string/StringImpl.h" #include "velox/functions/prestosql/json/JsonStringUtil.h" #include "velox/functions/prestosql/json/SIMDJsonUtil.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -22,6 +24,7 @@ namespace facebook::velox::functions { namespace { + const std::string_view kArrayStart = "["; const std::string_view kArrayEnd = "]"; const std::string_view kSeparator = ","; @@ -46,65 +49,46 @@ inline void addOrMergeViews(JsonViews& jsonViews, std::string_view view) { } } -/// Class to keep track of json strings being written -/// in to a buffer. The size of the backing buffer must be known during -/// construction time. -class BufferTracker { - public: - explicit BufferTracker(BufferPtr buffer) : curPos_(0), currentViewStart_(0) { - bufPtr_ = buffer->asMutable(); - capacity = buffer->capacity(); - } - - /// Write out all the views to the buffer. - auto getCanonicalString(JsonViews& jsonViews) { - for (auto view : jsonViews) { - trimEscapeWriteToBuffer(view); - } - return getStringView(); - } - - /// Sets current view to the end of the previous string. - /// Should be called only after getCanonicalString , - /// as after this call the previous view is lost. - void startNewString() { - currentViewStart_ += curPos_; - curPos_ = 0; - } - - private: - /// Trims whitespace and escapes utf characters before writing to buffer. - void trimEscapeWriteToBuffer(std::string_view input) { - auto trimmed = velox::util::trimWhiteSpace(input.data(), input.size()); - auto curBufPtr = getCurrentBufferPtr(); - auto bytesWritten = - prestoJavaEscapeString(trimmed.data(), trimmed.size(), curBufPtr); - incrementCounter(bytesWritten); - } - - /// Returns current string view against the buffer. - std::string_view getStringView() { - return std::string_view(bufPtr_ + currentViewStart_, curPos_); +void addOrMergeChar(JsonViews& views, std::string_view view) { + VELOX_DCHECK_EQ(view.size(), 1); + if (views.empty()) { + views.push_back(view); + return; } - - inline char* getCurrentBufferPtr() { - return bufPtr_ + currentViewStart_ + curPos_; + auto& last = views.back(); + // OK to do this because input is padded. + if (*last.end() == view[0]) { + last = std::string_view(last.data(), last.size() + 1); + } else { + views.push_back(view); } +} - void incrementCounter(size_t increment) { - VELOX_DCHECK_LE(curPos_ + currentViewStart_ + increment, capacity); - curPos_ += increment; +std::string_view trimToken(std::string_view token) { + VELOX_DCHECK(!stringImpl::isAsciiWhiteSpace(token[0])); + auto size = token.size(); + while (stringImpl::isAsciiWhiteSpace(token[size - 1])) { + --size; } + return std::string_view(token.data(), size); +} - size_t capacity; - size_t curPos_; - size_t currentViewStart_; - char* bufPtr_; +struct JsonField { + std::string_view key; + int32_t offset; + int32_t size; }; -} // namespace +size_t concatViews(const JsonViews& views, char* out) { + size_t total = 0; + for (auto& v : views) { + memcpy(out, v.data(), v.size()); + total += v.size(); + out += v.size(); + } + return total; +} -namespace { class JsonFormatFunction : public exec::VectorFunction { public: void apply( @@ -167,32 +151,26 @@ class JsonParseFunction : public exec::VectorFunction { const auto& arg = args[0]; if (arg->isConstantEncoding()) { auto value = arg->as>()->valueAt(0); - auto size = value.size(); - if (FOLLY_UNLIKELY(hasInvalidUTF8(value.data(), value.size()))) { - size = replaceInvalidUTF8Characters( - paddedInput_.data(), value.data(), value.size()); - paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); - } else { - paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); - memcpy(paddedInput_.data(), value.data(), size); - } - - auto escapeSize = escapedStringSize(value.data(), size); - auto buffer = AlignedBuffer::allocate(escapeSize, context.pool()); - BufferTracker bufferTracker{buffer}; - - JsonViews jsonViews; - - if (auto error = parse(size, jsonViews)) { + bool needNormalize = + needNormalizeForJsonParse(value.data(), value.size()); + auto size = needNormalize + ? normalizedSizeForJsonParse(value.data(), value.size()) + : value.size(); + paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); + VELOX_CHECK_EQ(prepareInput(value, needNormalize), size); + + auto buffer = AlignedBuffer::allocate(size, context.pool()); + if (auto error = parse(size, needNormalize)) { context.setErrors(rows, errors_[error]); return; } + auto* output = buffer->asMutable(); + auto outputSize = concatViews(views_, output); BufferPtr stringViews = - AlignedBuffer::allocate(1, context.pool(), StringView()); + AlignedBuffer::allocate(1, context.pool()); auto rawStringViews = stringViews->asMutable(); - rawStringViews[0] = - StringView(bufferTracker.getCanonicalString(jsonViews)); + rawStringViews[0] = StringView(output, outputSize); auto constantBase = std::make_shared>( context.pool(), @@ -214,43 +192,35 @@ class JsonParseFunction : public exec::VectorFunction { size_t maxSize = 0; size_t totalOutputSize = 0; + std::vector needNormalizes(rows.end()); rows.applyToSelected([&](auto row) { auto value = flatInput->valueAt(row); - maxSize = std::max(maxSize, value.size()); - totalOutputSize += escapedStringSize(value.data(), value.size()); + bool needNormalize = + needNormalizeForJsonParse(value.data(), value.size()); + auto size = needNormalize + ? normalizedSizeForJsonParse(value.data(), value.size()) + : value.size(); + needNormalizes[row] = needNormalize; + maxSize = std::max(maxSize, size); + totalOutputSize += size; }); paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING); BufferPtr buffer = AlignedBuffer::allocate(totalOutputSize, context.pool()); - BufferTracker bufferTracker{buffer}; + auto* output = buffer->asMutable(); rows.applyToSelected([&](auto row) { - JsonViews jsonViews; auto value = flatInput->valueAt(row); - auto size = value.size(); - if (FOLLY_UNLIKELY(hasInvalidUTF8(value.data(), size))) { - size = replaceInvalidUTF8Characters( - paddedInput_.data(), value.data(), size); - if (maxSize < size) { - paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); - maxSize = size; - } - } else { - // We clear out the buffer since SIMDJSON peeks past the size of the - // string and can throw if a ':' comes after a '"'. - // issue : https://github.com/simdjson/simdjson/issues/2312 - memset(paddedInput_.data(), 0, paddedInput_.size()); - memcpy(paddedInput_.data(), value.data(), size); - } - - if (auto error = parse(size, jsonViews)) { + auto size = prepareInput(value, needNormalizes[row]); + if (auto error = parse(size, needNormalizes[row])) { context.setVeloxExceptionError(row, errors_[error]); - } else { - auto canonicalString = bufferTracker.getCanonicalString(jsonViews); - - rawStringViews[row] = StringView(canonicalString); - bufferTracker.startNewString(); + return; + } + auto outputSize = concatViews(views_, output); + rawStringViews[row] = StringView(output, outputSize); + if (!StringView::isInline(outputSize)) { + output += outputSize; } }); @@ -275,116 +245,199 @@ class JsonParseFunction : public exec::VectorFunction { } private: - simdjson::error_code parse(size_t size, JsonViews& jsonViews) const { + struct FastSortKey { + static constexpr int kSize = 3; + std::array value; + }; + + size_t prepareInput(const StringView& value, bool needNormalize) const { + size_t outSize; + if (needNormalize) { + outSize = normalizeForJsonParse( + value.data(), value.size(), paddedInput_.data()); + } else { + memcpy(paddedInput_.data(), value.data(), value.size()); + outSize = value.size(); + } + memset(paddedInput_.data() + outSize, 0, simdjson::SIMDJSON_PADDING); + return outSize; + } + + simdjson::error_code parse(size_t size, bool needNormalize) const { simdjson::padded_string_view paddedInput( paddedInput_.data(), size, paddedInput_.size()); SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(paddedInput)); - SIMDJSON_TRY(validate(doc, jsonViews)); + views_.clear(); + if (needNormalize) { + SIMDJSON_TRY((generateViews(doc))); + } else { + SIMDJSON_TRY((generateViews(doc))); + } + VELOX_CHECK(fields_.empty()); if (!doc.at_end()) { return simdjson::TRAILING_CONTENT; } return simdjson::SUCCESS; } - template - static simdjson::error_code validate(T value, JsonViews& jsonViews) { + template + simdjson::error_code generateViews(T value) const { SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); switch (type) { case simdjson::ondemand::json_type::array: { SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array()); - - jsonViews.push_back(kArrayStart); - auto jsonViewsSize = jsonViews.size(); - for (auto elementOrError : array) { - SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError); - SIMDJSON_TRY(validate(element, jsonViews)); - jsonViews.push_back(kSeparator); - } - - // If the array is not empty, remove the last separator. - if (jsonViews.size() > jsonViewsSize) { - jsonViews.pop_back(); - } - - jsonViews.push_back(kArrayEnd); - - return simdjson::SUCCESS; + return generateViewsFromArray(array); } - case simdjson::ondemand::json_type::object: { SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object()); - - std::vector> objFields; - for (auto fieldOrError : object) { - SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError); - auto key = field.key_raw_json_token(); - JsonViews elementArray; - SIMDJSON_TRY(validate(field.value(), elementArray)); - objFields.push_back({key, elementArray}); - } - - std::sort(objFields.begin(), objFields.end(), [](auto& a, auto& b) { - // Remove the quotes from the keys before we sort them. - auto af = std::string_view{a.first.data() + 1, a.first.size() - 2}; - auto bf = std::string_view{b.first.data() + 1, b.first.size() - 2}; - return lessThan(af, bf); - }); - - jsonViews.push_back(kObjectStart); - - for (auto i = 0; i < objFields.size(); i++) { - auto field = objFields[i]; - addOrMergeViews(jsonViews, field.first); - jsonViews.push_back(kObjectKeySeparator); - - for (auto& element : field.second) { - addOrMergeViews(jsonViews, element); - } - - if (i < objFields.size() - 1) { - jsonViews.push_back(kSeparator); - } - } - - jsonViews.push_back(kObjectEnd); - return simdjson::SUCCESS; + return generateViewsFromObject(object); } - - case simdjson::ondemand::json_type::number: { - SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); - addOrMergeViews(jsonViews, rawJson); - + case simdjson::ondemand::json_type::number: + addOrMergeViews(views_, trimToken(value.raw_json_token())); return value.get_double().error(); - } - case simdjson::ondemand::json_type::string: { - SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); - addOrMergeViews(jsonViews, rawJson); - + case simdjson::ondemand::json_type::string: + addOrMergeViews(views_, trimToken(value.raw_json_token())); return value.get_string().error(); + case simdjson::ondemand::json_type::boolean: + addOrMergeViews(views_, trimToken(value.raw_json_token())); + return value.get_bool().error(); + case simdjson::ondemand::json_type::null: + SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null()); + addOrMergeViews(views_, trimToken(value.raw_json_token())); + return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR; + } + } + + template + simdjson::error_code generateViewsFromArray( + simdjson::ondemand::array array) const { + addOrMergeChar(views_, kArrayStart); + bool first = true; + for (auto elementOrError : array) { + SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError); + if (first) { + first = false; + } else { + addOrMergeChar(views_, kSeparator); } + SIMDJSON_TRY(generateViews(element)); + } + addOrMergeChar(views_, kArrayEnd); + return simdjson::SUCCESS; + } - case simdjson::ondemand::json_type::boolean: { - SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); - addOrMergeViews(jsonViews, rawJson); + template + simdjson::error_code generateViewsFromObject( + simdjson::ondemand::object object) const { + addOrMergeChar(views_, kObjectStart); + const auto oldNumFields = fields_.size(); + const auto oldNumViews = views_.size(); + for (auto fieldOrError : object) { + auto offset = views_.size(); + SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError); + auto key = field.escaped_key(); + views_.push_back({key.data() - 1, key.size() + 2}); + addOrMergeChar(views_, kObjectKeySeparator); + SIMDJSON_TRY(generateViews(field.value())); + auto& newField = fields_.emplace_back(); + newField.key = key; + newField.offset = offset; + newField.size = views_.size() - offset; + } + sortFields( + fields_.data() + oldNumFields, + fields_.size() - oldNumFields, + oldNumViews); + fields_.resize(oldNumFields); + addOrMergeChar(views_, kObjectEnd); + return simdjson::SUCCESS; + } - return value.get_bool().error(); + template + void sortFields(const JsonField* fields, int numFields, int oldNumViews) + const { + if (numFields <= 1) { + return; + } + const auto sortedBegin = views_.size(); + sortIndices_.resize(numFields); + std::iota(sortIndices_.begin(), sortIndices_.end(), 0); + if constexpr (kNeedNormalize) { + sortIndices([&](int32_t i, int32_t j) { + return lessThanForJsonParse(fields[i].key, fields[j].key); + }); + } else if (!fastSort(fields, numFields)) { + sortIndices( + [&](int32_t i, int32_t j) { return fields[i].key < fields[j].key; }); + } + for (auto i = 0; i < numFields; ++i) { + if (i > 0) { + addOrMergeChar(views_, kSeparator); } + auto& field = fields[sortIndices_[i]]; + for (int j = 0; j < field.size; ++j) { + views_.push_back(views_[field.offset + j]); + } + } + auto numNewViews = views_.size() - sortedBegin; + static_assert(std::is_trivially_copyable_v); + memmove( + &views_[oldNumViews], + &views_[sortedBegin], + sizeof(std::string_view) * numNewViews); + views_.resize(oldNumViews + numNewViews); + } - case simdjson::ondemand::json_type::null: { - SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null()); - SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); - addOrMergeViews(jsonViews, rawJson); - - return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR; + bool fastSort(const JsonField* fields, int numFields) const { + for (int i = 0; i < numFields; ++i) { + if (fields[i].key.size() > 8 * FastSortKey::kSize) { + return false; + } + } + fastSortKeys_.resize(numFields); + constexpr auto load = [](const char* s) { + return folly::Endian::big(folly::loadUnaligned(s)); + }; + for (int i = 0; i < numFields; ++i) { + const auto& s = fields[i].key; + auto& t = fastSortKeys_[i].value; + int j = 0; + while (8 * (j + 1) <= s.size()) { + t[j++] = load(s.data() + 8 * j); + } + auto r = s.size() - 8 * j; + if (r > 0) { + auto v = load(s.data() + 8 * j); + v >>= 8 - r; + v <<= 8 - r; + t[j] = v; } } - VELOX_UNREACHABLE(); + sortIndices([&](int32_t i, int32_t j) { + return fastSortKeys_[i].value < fastSortKeys_[j].value; + }); + return true; + } + + template + void sortIndices(LessThan&& lt) const { + if (sortIndices_.size() <= kSortingNetworkMaxSize) { + sortingNetwork( + sortIndices_.data(), sortIndices_.size(), std::forward(lt)); + } else { + std::sort( + sortIndices_.begin(), sortIndices_.end(), std::forward(lt)); + } } mutable folly::once_flag initializeErrors_; mutable std::exception_ptr errors_[simdjson::NUM_ERROR_CODES]; // Padding is needed in case string view is inlined. mutable std::string paddedInput_; + mutable JsonViews views_; + mutable std::vector fields_; + mutable std::vector sortIndices_; + mutable std::vector fastSortKeys_; }; } // namespace diff --git a/velox/functions/prestosql/json/JsonStringUtil.cpp b/velox/functions/prestosql/json/JsonStringUtil.cpp index 720b6d0f4b67b..05eac66e60c3f 100644 --- a/velox/functions/prestosql/json/JsonStringUtil.cpp +++ b/velox/functions/prestosql/json/JsonStringUtil.cpp @@ -123,7 +123,7 @@ void testingEncodeUtf16Hex(char32_t codePoint, char*& out) { encodeUtf16Hex(codePoint, out); } -void escapeString(const char* input, size_t length, char* output) { +void normalizeForJsonCast(const char* input, size_t length, char* output) { char* pos = output; auto* start = reinterpret_cast(input); @@ -165,7 +165,7 @@ void escapeString(const char* input, size_t length, char* output) { } } -size_t escapedStringSize(const char* input, size_t length) { +size_t normalizedSizeForJsonCast(const char* input, size_t length) { // 6 chars that is returned by `writeHex`. constexpr size_t kEncodedHexSize = 6; @@ -310,7 +310,9 @@ int32_t compareChars( } } // namespace -bool lessThan(const std::string_view& first, const std::string_view& second) { +bool lessThanForJsonParse( + const std::string_view& first, + const std::string_view& second) { size_t firstLength = first.size(); size_t secondLength = second.size(); size_t minLength = std::min(firstLength, secondLength); @@ -325,70 +327,65 @@ bool lessThan(const std::string_view& first, const std::string_view& second) { return firstLength < secondLength; } -size_t prestoJavaEscapeString(const char* input, size_t length, char* output) { +size_t normalizeForJsonParse(const char* input, size_t length, char* output) { char* pos = output; - - auto* start = reinterpret_cast(input); - auto* end = reinterpret_cast(input + length); + auto* start = input; + auto* end = input + length; while (start < end) { - int count = validateAndGetNextUtf8Length(start, end); - switch (count) { - case 1: { - // Unescape characters that are escaped by \ character. - if (FOLLY_UNLIKELY(*start == '\\')) { - if (start + 1 == end) { - VELOX_USER_FAIL("Invalid escape sequence at the end of string"); - } - // Presto java implementation only unescapes the / character. - switch (*(start + 1)) { - case '/': - *pos++ = '/'; - start += 2; - continue; - case 'u': { - if (start + 5 > end) { - VELOX_USER_FAIL("Invalid escape sequence at the end of string"); - } - - // Read 4 hex digits. - auto codePoint = parseHex(std::string_view( - reinterpret_cast(start) + 2, 4)); - - // Presto java implementation doesnt unescape surrogate pairs. - // Thus we just write it out in the same way as it is. - if (isHighSurrogate(codePoint) || isLowSurrogate(codePoint) || - isSpecialCode(codePoint)) { - *pos++ = '\\'; - *pos++ = 'u'; - start += 2; - // java upper cases the code points - for (auto k = 0; k < 4; k++) { - *pos++ = std::toupper(start[k]); - } - - start += 4; - continue; - } - - // Otherwise write it as a single code point. - auto increment = utf8proc_encode_char( - codePoint, reinterpret_cast(pos)); - pos += increment; - start += 6; - continue; + // Unescape characters that are escaped by \ character. + if (FOLLY_UNLIKELY(*start == '\\')) { + VELOX_USER_CHECK_NE( + start + 1, end, "Invalid escape sequence at the end of string"); + // Presto java implementation only unescapes the / character. + switch (*(start + 1)) { + case '/': + *pos++ = '/'; + start += 2; + continue; + case 'u': { + VELOX_USER_CHECK_LE( + start + 5, end, "Invalid escape sequence at the end of string"); + + // Read 4 hex digits. + auto codePoint = parseHex(std::string_view(start + 2, 4)); + + // Presto java implementation doesnt unescape surrogate pairs. + // Thus we just write it out in the same way as it is. + if (isHighSurrogate(codePoint) || isLowSurrogate(codePoint) || + isSpecialCode(codePoint)) { + *pos++ = '\\'; + *pos++ = 'u'; + start += 2; + // java upper cases the code points + for (auto k = 0; k < 4; k++) { + *pos++ = std::toupper(start[k]); } - default: - *pos++ = *start; - *pos++ = *(start + 1); - start += 2; - continue; + + start += 4; + continue; } - } else { - *pos++ = *start; - start++; + + // Otherwise write it as a single code point. + auto increment = utf8proc_encode_char( + codePoint, reinterpret_cast(pos)); + pos += increment; + start += 6; continue; } + default: + *pos++ = *start; + *pos++ = *(start + 1); + start += 2; + continue; } + } + if (FOLLY_LIKELY(IS_ASCII(*start))) { + *pos++ = *start++; + continue; + } + int32_t codePoint; + int count = tryGetUtf8CharLength(start, end - start, codePoint); + switch (count) { case 2: { memcpy(pos, reinterpret_cast(start), 2); pos += 2; @@ -402,22 +399,95 @@ size_t prestoJavaEscapeString(const char* input, size_t length, char* output) { continue; } case 4: { - char32_t codePoint = folly::utf8ToCodePoint(start, end, true); if (codePoint == U'\ufffd') { writeHex(0xFFFDu, pos); - continue; + } else { + encodeUtf16Hex(codePoint, pos); } - encodeUtf16Hex(codePoint, pos); + start += 4; continue; } default: { - writeHex(0xFFFDu, pos); - start++; + VELOX_DCHECK_LT(count, 0); + count = -count; + const auto& replacement = + getInvalidUTF8ReplacementString(start, end - start, count); + std::memcpy(pos, replacement.data(), replacement.size()); + pos += replacement.size(); + start += count; } } } + return pos - output; +} - return (pos - output); +size_t normalizedSizeForJsonParse(const char* input, size_t length) { + auto* start = input; + auto* end = input + length; + size_t outSize = 0; + while (start < end) { + if (FOLLY_UNLIKELY(*start == '\\')) { + VELOX_USER_CHECK_NE( + start + 1, end, "Invalid escape sequence at the end of string"); + switch (*(start + 1)) { + case '/': + ++outSize; + start += 2; + continue; + case 'u': { + VELOX_USER_CHECK_LE( + start + 5, end, "Invalid escape sequence at the end of string"); + auto codePoint = parseHex(std::string_view(start + 2, 4)); + if (isHighSurrogate(codePoint) || isLowSurrogate(codePoint) || + isSpecialCode(codePoint)) { + outSize += 6; + } else { + unsigned char buf[4]; + auto increment = utf8proc_encode_char(codePoint, buf); + outSize += increment; + } + start += 6; + continue; + } + default: + outSize += 2; + start += 2; + continue; + } + } + if (FOLLY_LIKELY(IS_ASCII(*start))) { + ++outSize; + ++start; + continue; + } + int32_t codePoint; + auto count = tryGetUtf8CharLength(start, end - start, codePoint); + switch (count) { + case 2: + case 3: + outSize += count; + start += count; + continue; + case 4: { + if (codePoint >= 0x10000u) { + outSize += 12; + } else { + outSize += 6; + } + start += 4; + continue; + } + default: { + VELOX_DCHECK_LT(count, 0); + count = -count; + const auto& replacement = + getInvalidUTF8ReplacementString(start, end - start, count); + outSize += replacement.size(); + start += count; + } + } + } + return outSize; } } // namespace facebook::velox diff --git a/velox/functions/prestosql/json/JsonStringUtil.h b/velox/functions/prestosql/json/JsonStringUtil.h index 384f6b768bd14..39c4b85895443 100644 --- a/velox/functions/prestosql/json/JsonStringUtil.h +++ b/velox/functions/prestosql/json/JsonStringUtil.h @@ -15,7 +15,10 @@ */ #pragma once +#include "velox/common/base/SimdUtil.h" + namespace facebook::velox { + /// Escape the unicode characters of `input` to make it canonical for JSON /// and legal to print in JSON text. It is assumed that the input is UTF-8 /// encoded. @@ -38,7 +41,15 @@ namespace facebook::velox { /// @param length: Length of the input string. /// @param output: Output string to write the escaped input to. The caller is /// responsible to allocate enough space for output. -void escapeString(const char* input, size_t length, char* output); +void normalizeForJsonCast(const char* input, size_t length, char* output); + +/// Return the size of string after the unicode characters of `input` are +/// escaped using the method as in`escapeString`. The function will iterate +/// over `input` once. +/// @param input: Input string to escape that is UTF-8 encoded. +/// @param length: Length of the input string. +/// @return The size of the string after escaping. +size_t normalizedSizeForJsonCast(const char* input, size_t length); /// Unescape the unicode characters of `input` to make it canonical for JSON /// The behavior is compatible with Presto Java's json_parse. @@ -50,25 +61,43 @@ void escapeString(const char* input, size_t length, char* output); /// @param output: Output string to write the escaped input to. The caller is /// responsible to allocate enough space for output. /// @return The number of bytes written to the output. -size_t prestoJavaEscapeString(const char* input, size_t length, char* output); +size_t normalizeForJsonParse(const char* input, size_t length, char* output); -/// Return the size of string after the unicode characters of `input` are -/// escaped using the method as in`escapeString`. The function will iterate -/// over `input` once. -/// @param input: Input string to escape that is UTF-8 encoded. -/// @param length: Length of the input string. -/// @return The size of the string after escaping. -size_t escapedStringSize(const char* input, size_t length); +size_t normalizedSizeForJsonParse(const char* input, size_t length); + +/// Return whether the string need normalize or special treatment for sort +/// object string keys (i.e. lessThanForJsonParse below). +inline bool needNormalizeForJsonParse(const char* input, size_t length) { + const auto unicodeMask = xsimd::broadcast(0x80); + const auto escape = xsimd::broadcast('\\'); + size_t i = 0; + for (; i + unicodeMask.size <= length; i += unicodeMask.size) { + auto batch = + xsimd::load_unaligned(reinterpret_cast(input) + i); + if (xsimd::any(batch >= unicodeMask || batch == escape)) { + return true; + } + } + for (; i < length; ++i) { + if ((input[i] & 0x80) || (input[i] == '\\')) { + return true; + } + } + return false; +} /// Compares two string views. The comparison takes into account /// escape sequences and also unicode characters. /// Returns true if first is less than second else false. /// @param first: First string to compare. /// @param second: Second string to compare. -bool lessThan(const std::string_view& first, const std::string_view& second); +bool lessThanForJsonParse( + const std::string_view& first, + const std::string_view& second); /// For test only. Encode `codePoint` value by UTF-16 and write the one or two /// prefixed hexadecimals to `out`. Move `out` forward by 6 or 12 chars /// accordingly. The caller shall ensure there is enough space in `out`. void testingEncodeUtf16Hex(char32_t codePoint, char*& out); + } // namespace facebook::velox diff --git a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp index fa0bd7057d280..03fa2fcf7ab51 100644 --- a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp @@ -267,9 +267,7 @@ TEST_F(JsonFunctionsTest, jsonParse) { velox::test::assertEqualVectors( expectedVector, evaluate("try(json_parse(c0))", data)); - VELOX_ASSERT_THROW( - evaluate("json_parse(c0)", data), - "TAPE_ERROR: The JSON document has an improper structure: missing or superfluous commas, braces, missing keys, etc."); + VELOX_ASSERT_THROW(evaluate("json_parse(c0)", data), "TRAILING_CONTENT"); data = makeRowVector({makeFlatVector( {R"("This is a long sentence")", R"("This is some other sentence")"})}); diff --git a/velox/functions/prestosql/types/JsonType.cpp b/velox/functions/prestosql/types/JsonType.cpp index 35558028336fe..d2fd99bc11bc7 100644 --- a/velox/functions/prestosql/types/JsonType.cpp +++ b/velox/functions/prestosql/types/JsonType.cpp @@ -50,10 +50,10 @@ void generateJsonTyped( auto value = input.valueAt(row); if constexpr (std::is_same_v) { - size_t resultSize = escapedStringSize(value.data(), value.size()); + size_t resultSize = normalizedSizeForJsonCast(value.data(), value.size()); result.resize(resultSize + 2); result.data()[0] = '"'; - escapeString(value.data(), value.size(), result.data() + 1); + normalizeForJsonCast(value.data(), value.size(), result.data() + 1); result.data()[resultSize + 1] = '"'; } else if constexpr (std::is_same_v) { VELOX_FAIL(