Skip to content

Commit

Permalink
Add support for canonicalization of JSON. (#11284)
Browse files Browse the repository at this point in the history
Summary:
This is preliminary PR that adds support for canonicalization of JSON strings. This initial PR only tackles canonicalization of json_parse. Another diff will handle CAST( _ as JSON) . Canonicalization is required since currently Velox just treats JSON as varchars thus equivalent JSON but having different backing varchar's are treated as separate JSON's which is incorrect and contrary to behavior shown by Presto.


Differential Revision: D65084925

Pulled By: kgpai
  • Loading branch information
kgpai authored and facebook-github-bot committed Nov 4, 2024
1 parent c95f1e0 commit 8b17860
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 25 deletions.
210 changes: 195 additions & 15 deletions velox/functions/prestosql/JsonFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,89 @@
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/functions/prestosql/json/JsonStringUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"

namespace facebook::velox::functions {

namespace {
const auto kArrayStart = "["_sv;
const auto kArrayEnd = "]"_sv;
const auto kSeparator = ","_sv;
const auto kObjectStart = "{"_sv;
const auto kObjectEnd = "}"_sv;
const auto kObjectKeySeparator = ":"_sv;

/// Class to keep track of json strings being written
/// in to a buffer. The size of the backing buffer must be known during
/// construction time.
class BufferTracker {
public:
BufferTracker(size_t bufferSize, memory::MemoryPool* pool)
: curPos_(0), currentViewStart_(0) {
buffer_ = AlignedBuffer::allocate<char>(bufferSize, pool);
bufPtr_ = buffer_->asMutable<char>();
}

/// Write out all the views to the buffer.
auto getCanonicalString(std::vector<StringView>& jsonViews) {
for (auto view : jsonViews) {
trimEscapeWriteToBuffer(view);
}
return getStringView();
}

/// Sets current view to the end of the previous string.
/// Should be called only after getCanonicalString ,
/// as after this call the previous view is lost.
void startNewString() {
currentViewStart_ += curPos_;
curPos_ = 0;
}

/// Returns the underlying buffer where the json strings are saved.
/// Resets everything to the initial state.
/// The caller must ensure that the buffer is not used after this call.
BufferPtr&& getUnderlyingBuffer() {
currentViewStart_ = 0;
curPos_ = 0;
return std::move(buffer_);
}

private:
/// Trims whitespace and escapes utf characters before writing to buffer.
void trimEscapeWriteToBuffer(StringView input) {
auto trimmed = velox::util::trimWhiteSpace(input.data(), input.size());
auto curBufPtr = getCurrentBufferPtr();
auto bytesWritten =
escapeString(trimmed.data(), trimmed.size(), curBufPtr, true);
incrementCounter(bytesWritten);
}

/// Returns current string view against the buffer.
StringView getStringView() {
return StringView(bufPtr_ + currentViewStart_, curPos_);
}

inline char* getCurrentBufferPtr() {
return bufPtr_ + currentViewStart_ + curPos_;
}

void incrementCounter(size_t increment) {
VELOX_CHECK_LE(
curPos_ + currentViewStart_ + increment, buffer_->capacity());
curPos_ += increment;
}

BufferPtr buffer_;
size_t curPos_;
size_t currentViewStart_;
char* bufPtr_;
};

} // namespace

namespace {
class JsonFormatFunction : public exec::VectorFunction {
public:
Expand Down Expand Up @@ -84,38 +162,71 @@ class JsonParseFunction : public exec::VectorFunction {
auto value = arg->as<ConstantVector<StringView>>()->valueAt(0);
paddedInput_.resize(value.size() + simdjson::SIMDJSON_PADDING);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
auto escapeSize = escapedStringSize(value.data(), value.size(), true);
BufferTracker bufferTracker{escapeSize, context.pool()};

std::vector<StringView> jsonViews;

if (auto error = parse(value.size(), jsonViews)) {
context.setErrors(rows, errors_[error]);
return;
}
localResult = std::make_shared<ConstantVector<StringView>>(
context.pool(), rows.end(), false, JSON(), std::move(value));

BufferPtr stringViews =
AlignedBuffer::allocate<StringView>(1, context.pool(), StringView());
auto rawStringViews = stringViews->asMutable<StringView>();
rawStringViews[0] = bufferTracker.getCanonicalString(jsonViews);

auto constantBase = std::make_shared<FlatVector<StringView>>(
context.pool(),
JSON(),
nullptr,
1,
stringViews,
std::vector<BufferPtr>{bufferTracker.getUnderlyingBuffer()});

localResult = BaseVector::wrapInConstant(rows.end(), 0, constantBase);

} else {
auto flatInput = arg->asFlatVector<StringView>();
BufferPtr stringViews = AlignedBuffer::allocate<StringView>(
rows.end(), context.pool(), StringView());
auto rawStringViews = stringViews->asMutable<StringView>();

auto stringBuffers = flatInput->stringBuffers();
VELOX_CHECK_LE(rows.end(), flatInput->size());

size_t maxSize = 0;
size_t totalOutputSize = 0;
rows.applyToSelected([&](auto row) {
auto value = flatInput->valueAt(row);
maxSize = std::max(maxSize, value.size());
totalOutputSize += escapedStringSize(value.data(), value.size(), true);
});

paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING);
BufferTracker bufferTracker{totalOutputSize, context.pool()};

rows.applyToSelected([&](auto row) {
std::vector<StringView> jsonViews;
auto value = flatInput->valueAt(row);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
if (auto error = parse(value.size(), jsonViews)) {
context.setVeloxExceptionError(row, errors_[error]);
} else {
auto canonicalString = bufferTracker.getCanonicalString(jsonViews);

rawStringViews[row] = canonicalString;
bufferTracker.startNewString();
}
});

localResult = std::make_shared<FlatVector<StringView>>(
context.pool(),
JSON(),
nullptr,
rows.end(),
flatInput->values(),
std::move(stringBuffers));
stringViews,
std::vector<BufferPtr>{bufferTracker.getUnderlyingBuffer()});
}

context.moveOrCopyResult(localResult, rows, result);
Expand All @@ -130,45 +241,114 @@ class JsonParseFunction : public exec::VectorFunction {
}

private:
simdjson::error_code parse(size_t size) const {
simdjson::error_code parse(size_t size, std::vector<StringView>& jsonViews)
const {
simdjson::padded_string_view paddedInput(
paddedInput_.data(), size, paddedInput_.size());
SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(paddedInput));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc, jsonViews));
if (!doc.at_end()) {
return simdjson::TRAILING_CONTENT;
}
return simdjson::SUCCESS;
}

template <typename T>
static simdjson::error_code validate(T value) {
static simdjson::error_code validate(
T value,
std::vector<StringView>& jsonViews) {
SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type());
switch (type) {
case simdjson::ondemand::json_type::array: {
SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array());

jsonViews.push_back(kArrayStart);
for (auto elementOrError : array) {
SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError);
SIMDJSON_TRY(validate(element));
std::vector<StringView> arrayElement;
SIMDJSON_TRY(validate(element, arrayElement));
jsonViews.insert(
jsonViews.end(),
std::make_move_iterator(arrayElement.begin()),
std::make_move_iterator(arrayElement.end()));
jsonViews.push_back(kSeparator);
}

// Remove last separator.
jsonViews.pop_back();
jsonViews.push_back(kArrayEnd);

return simdjson::SUCCESS;
}

case simdjson::ondemand::json_type::object: {
SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object());

std::vector<std::pair<StringView, std::vector<StringView>>> objFields;
for (auto fieldOrError : object) {
SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError);
SIMDJSON_TRY(validate(field.value()));
auto key = StringView(field.key_raw_json_token());
std::vector<StringView> elementArray;
SIMDJSON_TRY(validate(field.value(), elementArray));
objFields.push_back({key, elementArray});
}

std::sort(objFields.begin(), objFields.end(), [](auto& a, auto& b) {
return a.first < b.first;
});

jsonViews.push_back(kObjectStart);

for (auto i = 0; i < objFields.size(); i++) {
auto field = objFields[i];
jsonViews.push_back(field.first);
jsonViews.push_back(kObjectKeySeparator);

jsonViews.insert(
jsonViews.end(),
std::make_move_iterator(field.second.begin()),
std::make_move_iterator(field.second.end()));

if (i < objFields.size() - 1) {
jsonViews.push_back(kSeparator);
}
}

jsonViews.push_back(kObjectEnd);
return simdjson::SUCCESS;
}
case simdjson::ondemand::json_type::number:

case simdjson::ondemand::json_type::number: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_double().error();
case simdjson::ondemand::json_type::string:
}
case simdjson::ondemand::json_type::string: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_string().error();
case simdjson::ondemand::json_type::boolean:
}

case simdjson::ondemand::json_type::boolean: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_bool().error();
}

case simdjson::ondemand::json_type::null: {
SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null());
auto rawJsonv = StringView(value.raw_json_token());

jsonViews.push_back(rawJsonv);
return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR;
}
}
Expand Down
21 changes: 17 additions & 4 deletions velox/functions/prestosql/json/JsonStringUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ void testingEncodeUtf16Hex(char32_t codePoint, char*& out) {
encodeUtf16Hex(codePoint, out);
}

void escapeString(const char* input, size_t length, char* output) {
size_t
escapeString(const char* input, size_t length, char* output, bool skipAscii) {
char* pos = output;

auto* start = reinterpret_cast<const unsigned char*>(input);
Expand All @@ -117,7 +118,12 @@ void escapeString(const char* input, size_t length, char* output) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1: {
encodeAscii(int8_t(*start), pos);
if (!skipAscii) {
encodeAscii(int8_t(*start), pos);
} else {
*pos++ = *start;
}

start++;
continue;
}
Expand Down Expand Up @@ -148,9 +154,11 @@ void escapeString(const char* input, size_t length, char* output) {
}
}
}

return (pos - output);
}

size_t escapedStringSize(const char* input, size_t length) {
size_t escapedStringSize(const char* input, size_t length, bool skipAscii) {
// 6 chars that is returned by `writeHex`.
constexpr size_t kEncodedHexSize = 6;

Expand All @@ -162,7 +170,12 @@ size_t escapedStringSize(const char* input, size_t length) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1:
outSize += encodedAsciiSizes[int8_t(*start)];
if (!skipAscii) {
outSize += encodedAsciiSizes[int8_t(*start)];
} else {
outSize++;
}

break;
case 2:
case 3:
Expand Down
14 changes: 12 additions & 2 deletions velox/functions/prestosql/json/JsonStringUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,24 @@ namespace facebook::velox {
/// @param length: Length of the input string.
/// @param output: Output string to write the escaped input to. The caller is
/// responsible to allocate enough space for output.
void escapeString(const char* input, size_t length, char* output);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
/// @return The number of bytes written to the output.
size_t escapeString(
const char* input,
size_t length,
char* output,
bool skipAscii = false);

/// Return the size of string after the unicode characters of `input` are
/// escaped using the method as in`escapeString`. The function will iterate
/// over `input` once.
/// @param input: Input string to escape that is UTF-8 encoded.
/// @param length: Length of the input string.
size_t escapedStringSize(const char* input, size_t length);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
size_t
escapedStringSize(const char* input, size_t length, bool skipAscii = false);

/// For test only. Encode `codePoint` value by UTF-16 and write the one or two
/// prefixed hexadecimals to `out`. Move `out` forward by 6 or 12 chars
Expand Down
Loading

0 comments on commit 8b17860

Please sign in to comment.