Skip to content

Commit

Permalink
Add support for canonicalization of JSON.
Browse files Browse the repository at this point in the history
  • Loading branch information
kgpai committed Nov 1, 2024
1 parent 92779f9 commit 80fae98
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 25 deletions.
206 changes: 191 additions & 15 deletions velox/functions/prestosql/JsonFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,85 @@
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/functions/prestosql/json/JsonStringUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"

namespace facebook::velox::functions {

namespace {
const auto kArrayStart = "["_sv;
const auto kArrayEnd = "]"_sv;
const auto kSeparator = ","_sv;
const auto kObjectStart = "{"_sv;
const auto kObjectEnd = "}"_sv;
const auto kObjectKeySeparator = ":"_sv;

/// Class to keep track of json strings being written
/// in to a buffer. The size of the backing buffer must be known during
/// construction time.
class BufferTracker {
public:
BufferTracker(size_t bufferSize, memory::MemoryPool* pool)
: curPos_(0), currentViewStart_(0) {
buffer_ = AlignedBuffer::allocate<char>(bufferSize, pool);
bufPtr_ = buffer_->asMutable<char>();
}

/// Write out all the views to the buffer.
auto getCanonicalString(std::vector<StringView>& jsonViews) {
for (auto view : jsonViews) {
trimEscapeWriteToBuffer(view);
}
return getStringView();
}

/// Sets current view to the end of the previous string.
/// Should be called only after getCanonicalString ,
/// as after this call the previous view is lost.
void startNewString() {
currentViewStart_ += curPos_;
curPos_ = 0;
}

/// Returns the underlying buffer where the json strings are saved.
BufferPtr getUnderlyingBuffer() {
return buffer_;
}

private:
/// Trims whitespace and escapes utf characters before writing to buffer.
void trimEscapeWriteToBuffer(StringView input) {
auto trimmed = velox::util::trimWhiteSpace(input.data(), input.size());
auto curBufPtr = getCurrentBufferPtr();
auto bytesWritten =
escapeString(trimmed.data(), trimmed.size(), curBufPtr, true);
incrementCounter(bytesWritten);
}

/// Returns current string view against the buffer.
StringView getStringView() {
return StringView(bufPtr_ + currentViewStart_, curPos_);
}

inline char* getCurrentBufferPtr() {
return bufPtr_ + currentViewStart_ + curPos_;
}

void incrementCounter(size_t increment) {
VELOX_CHECK_LE(
curPos_ + currentViewStart_ + increment, buffer_->capacity());
curPos_ += increment;
}

BufferPtr buffer_;
size_t curPos_;
size_t currentViewStart_;
char* bufPtr_;
};

} // namespace

namespace {
class JsonFormatFunction : public exec::VectorFunction {
public:
Expand Down Expand Up @@ -84,38 +158,71 @@ class JsonParseFunction : public exec::VectorFunction {
auto value = arg->as<ConstantVector<StringView>>()->valueAt(0);
paddedInput_.resize(value.size() + simdjson::SIMDJSON_PADDING);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
auto escapeSize = escapedStringSize(value.data(), value.size(), true);
BufferTracker bufferTracker{escapeSize, context.pool()};

std::vector<StringView> jsonViews;

if (auto error = parse(value.size(), jsonViews)) {
context.setErrors(rows, errors_[error]);
return;
}
localResult = std::make_shared<ConstantVector<StringView>>(
context.pool(), rows.end(), false, JSON(), std::move(value));

BufferPtr stringViews =
AlignedBuffer::allocate<StringView>(1, context.pool(), StringView());
auto rawStringViews = stringViews->asMutable<StringView>();
rawStringViews[0] = bufferTracker.getCanonicalString(jsonViews);

auto constantBase = std::make_shared<FlatVector<StringView>>(
context.pool(),
JSON(),
nullptr,
1,
stringViews,
std::vector<BufferPtr>{bufferTracker.getUnderlyingBuffer()});

localResult = BaseVector::wrapInConstant(rows.end(), 0, constantBase);

} else {
auto flatInput = arg->asFlatVector<StringView>();
BufferPtr stringViews = AlignedBuffer::allocate<StringView>(
rows.end(), context.pool(), StringView());
auto rawStringViews = stringViews->asMutable<StringView>();

auto stringBuffers = flatInput->stringBuffers();
VELOX_CHECK_LE(rows.end(), flatInput->size());

size_t maxSize = 0;
size_t totalOutputSize = 0;
rows.applyToSelected([&](auto row) {
auto value = flatInput->valueAt(row);
maxSize = std::max(maxSize, value.size());
totalOutputSize += escapedStringSize(value.data(), value.size(), true);
});

paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING);
BufferTracker bufferTracker{totalOutputSize, context.pool()};

rows.applyToSelected([&](auto row) {
std::vector<StringView> jsonViews;
auto value = flatInput->valueAt(row);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
if (auto error = parse(value.size(), jsonViews)) {
context.setVeloxExceptionError(row, errors_[error]);
} else {
auto canonicalString = bufferTracker.getCanonicalString(jsonViews);

rawStringViews[row] = canonicalString;
bufferTracker.startNewString();
}
});

localResult = std::make_shared<FlatVector<StringView>>(
context.pool(),
JSON(),
nullptr,
rows.end(),
flatInput->values(),
std::move(stringBuffers));
stringViews,
std::vector<BufferPtr>{bufferTracker.getUnderlyingBuffer()});
}

context.moveOrCopyResult(localResult, rows, result);
Expand All @@ -130,45 +237,114 @@ class JsonParseFunction : public exec::VectorFunction {
}

private:
simdjson::error_code parse(size_t size) const {
simdjson::error_code parse(size_t size, std::vector<StringView>& jsonViews)
const {
simdjson::padded_string_view paddedInput(
paddedInput_.data(), size, paddedInput_.size());
SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(paddedInput));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc, jsonViews));
if (!doc.at_end()) {
return simdjson::TRAILING_CONTENT;
}
return simdjson::SUCCESS;
}

template <typename T>
static simdjson::error_code validate(T value) {
static simdjson::error_code validate(
T value,
std::vector<StringView>& jsonViews) {
SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type());
switch (type) {
case simdjson::ondemand::json_type::array: {
SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array());

jsonViews.push_back(kArrayStart);
for (auto elementOrError : array) {
SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError);
SIMDJSON_TRY(validate(element));
std::vector<StringView> arrayElement;
SIMDJSON_TRY(validate(element, arrayElement));
jsonViews.insert(
jsonViews.end(),
std::make_move_iterator(arrayElement.begin()),
std::make_move_iterator(arrayElement.end()));
jsonViews.push_back(kSeparator);
}

// Remove last separator.
jsonViews.pop_back();
jsonViews.push_back(kArrayEnd);

return simdjson::SUCCESS;
}

case simdjson::ondemand::json_type::object: {
SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object());

std::vector<std::pair<StringView, std::vector<StringView>>> objFields;
for (auto fieldOrError : object) {
SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError);
SIMDJSON_TRY(validate(field.value()));
auto key = StringView(field.key_raw_json_token());
std::vector<StringView> elementArray;
SIMDJSON_TRY(validate(field.value(), elementArray));
objFields.push_back({key, elementArray});
}

std::sort(objFields.begin(), objFields.end(), [](auto& a, auto& b) {
return a.first < b.first;
});

jsonViews.push_back(kObjectStart);

for (auto i = 0; i < objFields.size(); i++) {
auto field = objFields[i];
jsonViews.push_back(field.first);
jsonViews.push_back(kObjectKeySeparator);

jsonViews.insert(
jsonViews.end(),
std::make_move_iterator(field.second.begin()),
std::make_move_iterator(field.second.end()));

if (i < objFields.size() - 1) {
jsonViews.push_back(kSeparator);
}
}

jsonViews.push_back(kObjectEnd);
return simdjson::SUCCESS;
}
case simdjson::ondemand::json_type::number:

case simdjson::ondemand::json_type::number: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_double().error();
case simdjson::ondemand::json_type::string:
}
case simdjson::ondemand::json_type::string: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_string().error();
case simdjson::ondemand::json_type::boolean:
}

case simdjson::ondemand::json_type::boolean: {
SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json());

auto rawJsonv = StringView(rawJson);
jsonViews.push_back(rawJsonv);

return value.get_bool().error();
}

case simdjson::ondemand::json_type::null: {
SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null());
auto rawJsonv = StringView(value.raw_json_token());

jsonViews.push_back(rawJsonv);
return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR;
}
}
Expand Down
21 changes: 17 additions & 4 deletions velox/functions/prestosql/json/JsonStringUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ void testingEncodeUtf16Hex(char32_t codePoint, char*& out) {
encodeUtf16Hex(codePoint, out);
}

void escapeString(const char* input, size_t length, char* output) {
size_t
escapeString(const char* input, size_t length, char* output, bool skipAscii) {
char* pos = output;

auto* start = reinterpret_cast<const unsigned char*>(input);
Expand All @@ -117,7 +118,12 @@ void escapeString(const char* input, size_t length, char* output) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1: {
encodeAscii(int8_t(*start), pos);
if (!skipAscii) {
encodeAscii(int8_t(*start), pos);
} else {
*pos++ = *start;
}

start++;
continue;
}
Expand Down Expand Up @@ -148,9 +154,11 @@ void escapeString(const char* input, size_t length, char* output) {
}
}
}

return (pos - output);
}

size_t escapedStringSize(const char* input, size_t length) {
size_t escapedStringSize(const char* input, size_t length, bool skipAscii) {
// 6 chars that is returned by `writeHex`.
constexpr size_t kEncodedHexSize = 6;

Expand All @@ -162,7 +170,12 @@ size_t escapedStringSize(const char* input, size_t length) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1:
outSize += encodedAsciiSizes[int8_t(*start)];
if (!skipAscii) {
outSize += encodedAsciiSizes[int8_t(*start)];
} else {
outSize++;
}

break;
case 2:
case 3:
Expand Down
14 changes: 12 additions & 2 deletions velox/functions/prestosql/json/JsonStringUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,24 @@ namespace facebook::velox {
/// @param length: Length of the input string.
/// @param output: Output string to write the escaped input to. The caller is
/// responsible to allocate enough space for output.
void escapeString(const char* input, size_t length, char* output);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
/// @return The number of bytes written to the output.
size_t escapeString(
const char* input,
size_t length,
char* output,
bool skipAscii = false);

/// Return the size of string after the unicode characters of `input` are
/// escaped using the method as in`escapeString`. The function will iterate
/// over `input` once.
/// @param input: Input string to escape that is UTF-8 encoded.
/// @param length: Length of the input string.
size_t escapedStringSize(const char* input, size_t length);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
size_t
escapedStringSize(const char* input, size_t length, bool skipAscii = false);

/// For test only. Encode `codePoint` value by UTF-16 and write the one or two
/// prefixed hexadecimals to `out`. Move `out` forward by 6 or 12 chars
Expand Down
Loading

0 comments on commit 80fae98

Please sign in to comment.