From a1adafe920ada63985c00a5a2eebc5383763b057 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Thu, 19 Dec 2024 14:23:04 -0800 Subject: [PATCH] feat: Add Spark get_json_object function (#11691) Summary: This PR proposes an implementation for Spark get_json_object function based on simdjson lib. This function returns a json object, represented by VARCHAR, from json string by searching user-specified path. Spark source code: [link](https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala#L127). Pull Request resolved: https://github.com/facebookincubator/velox/pull/11691 Reviewed By: xiaoxmeng Differential Revision: D67119142 Pulled By: kgpai fbshipit-source-id: f4a4259a1bd54c6bb6e7811480f764a9f1a0373a --- velox/docs/functions/spark/json.rst | 35 ++- velox/functions/sparksql/GetJsonObject.h | 219 ++++++++++++++++++ .../sparksql/registration/RegisterJson.cpp | 3 + velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/GetJsonObjectTest.cpp | 123 ++++++++++ 5 files changed, 380 insertions(+), 1 deletion(-) create mode 100644 velox/functions/sparksql/GetJsonObject.h create mode 100644 velox/functions/sparksql/tests/GetJsonObjectTest.cpp diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst index 800487398688..5f853f0698e1 100644 --- a/velox/docs/functions/spark/json.rst +++ b/velox/docs/functions/spark/json.rst @@ -2,9 +2,42 @@ JSON Functions ============== +JSON Format +----------- + +JSON is a language-independent data format that represents data as +human-readable text. A JSON text can represent a number, a boolean, a +string, an array, an object, or a null. A JSON text representing a string +must escape all characters and enclose the string in double quotes, e.g., +``"123\n"``, whereas a JSON text representing a number does not need to, +e.g., ``123``. 
A JSON text representing an array must enclose the array +elements in square brackets, e.g., ``[1,2,3]``. More detailed grammar can +be found in `this JSON introduction`_. + +.. _this JSON introduction: https://www.json.org + +JSON Functions +-------------- + +.. spark:function:: get_json_object(jsonString, path) -> varchar + + Returns a json object, represented by VARCHAR, from ``jsonString`` by searching ``path``. + Valid ``path`` should start with '$' and then contain "[index]", "['field']" or ".field" + to define a JSON path. Here are some examples: "$.a", "$.a.b", "$[0]['a'].b". Returns + ``jsonString`` if ``path`` is "$". Returns NULL if ``jsonString`` or ``path`` is malformed. + Returns NULL if ``path`` does not exist. :: + + SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b' + SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}' + SELECT get_json_object('{"a":3}', '$.b'); -- NULL (non-existent field) + SELECT get_json_object('{"a"-3}', '$.a'); -- NULL (malformed JSON string) + SELECT get_json_object('{"a":3}', '.a'); -- NULL (malformed JSON path) + .. spark:function:: json_object_keys(jsonString) -> array(string) - Returns all the keys of the outermost JSON object as an array if a valid JSON object is given. If it is any other valid JSON string, an invalid JSON string or an empty string, the function returns null. :: + Returns all the keys of the outermost JSON object as an array if a valid JSON object is given. + If it is any other valid JSON string, an invalid JSON string or an empty string, the function + returns null. :: SELECT json_object_keys('{}'); -- [] SELECT json_object_keys('{"name": "Alice", "age": 5, "id": "001"}'); -- ['name', 'age', 'id'] diff --git a/velox/functions/sparksql/GetJsonObject.h b/velox/functions/sparksql/GetJsonObject.h new file mode 100644 index 000000000000..5bb249974338 --- /dev/null +++ b/velox/functions/sparksql/GetJsonObject.h @@ -0,0 +1,219 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/Macros.h" +#include "velox/functions/prestosql/json/SIMDJsonUtil.h" + +namespace facebook::velox::functions::sparksql { + +/// Parses a JSON string and returns the value at the specified path. +/// Simdjson On-Demand API is used to parse JSON string. +/// get_json_object(jsonString, path) -> value +template +struct GetJsonObjectFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // ASCII input always produces ASCII result. + static constexpr bool is_default_ascii_behavior = true; + + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& /*inputTypes*/, + const core::QueryConfig& /*config*/, + const arg_type* /*json*/, + const arg_type* jsonPath) { + if (jsonPath != nullptr && checkJsonPath(*jsonPath)) { + jsonPath_ = removeSingleQuotes(*jsonPath); + } + } + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& json, + const arg_type& jsonPath) { + // Spark requires the first char in jsonPath is '$'. + if (!checkJsonPath(jsonPath)) { + return false; + } + // jsonPath is "$". + if (jsonPath.size() == 1) { + result.append(json); + return true; + } + simdjson::ondemand::document jsonDoc; + simdjson::padded_string paddedJson(json.data(), json.size()); + if (simdjsonParse(paddedJson).get(jsonDoc)) { + return false; + } + const auto formattedJsonPath = jsonPath_.has_value() + ? 
jsonPath_.value() + : removeSingleQuotes(jsonPath); + try { + // Can return error result or throw exception possibly. + auto rawResult = jsonDoc.at_path(formattedJsonPath); + if (rawResult.error()) { + return false; + } + + if (!extractStringResult(rawResult, result)) { + return false; + } + } catch (simdjson::simdjson_error& e) { + return false; + } + + const char* currentPos; + if (jsonDoc.current_location().get(currentPos)) { + return false; + } + + return isValidEndingCharacter(currentPos); + } + + private: + FOLLY_ALWAYS_INLINE bool checkJsonPath(StringView jsonPath) { + // Spark requires the first char in jsonPath is '$'. + if (jsonPath.empty() || jsonPath.data()[0] != '$') { + return false; + } + return true; + } + + // Spark's json path requires field name surrounded by single quotes if it is + // specified in "[]". But simdjson lib requires not. This method just removes + // such single quotes to adapt to simdjson lib, e.g., converts "['a']['b']" to + // "[a][b]". + std::string removeSingleQuotes(StringView jsonPath) { + // Skip the initial "$". + std::string result(jsonPath.data() + 1, jsonPath.size() - 1); + size_t pairEnd = 0; + while (true) { + auto pairBegin = result.find("['", pairEnd); + if (pairBegin == std::string::npos) { + break; + } + pairEnd = result.find("]", pairBegin); + // If expected pattern, like ['a'], is not found. + if (pairEnd == std::string::npos || result[pairEnd - 1] != '\'') { + return "-1"; + } + result.erase(pairEnd - 1, 1); + result.erase(pairBegin + 1, 1); + pairEnd -= 2; + } + return result; + } + + // Extracts a string representation from a simdjson result. Handles various + // JSON types including numbers, booleans, strings, objects, and arrays. + // Returns true if the conversion is successful. Otherwise, returns false. 
+ bool extractStringResult( + simdjson::simdjson_result rawResult, + out_type& result) { + std::stringstream ss; + switch (rawResult.type()) { + // For number and bool types, we need to explicitly get the value + // for specific types instead of using `ss << rawResult`. Thus, we + // can make simdjson's internal parsing position moved and then we + // can check the validity of ending character. + case simdjson::ondemand::json_type::number: { + switch (rawResult.get_number_type()) { + case simdjson::ondemand::number_type::unsigned_integer: { + uint64_t numberResult; + if (!rawResult.get_uint64().get(numberResult)) { + ss << numberResult; + result.append(ss.str()); + return true; + } + return false; + } + case simdjson::ondemand::number_type::signed_integer: { + int64_t numberResult; + if (!rawResult.get_int64().get(numberResult)) { + ss << numberResult; + result.append(ss.str()); + return true; + } + return false; + } + case simdjson::ondemand::number_type::floating_point_number: { + double numberResult; + if (!rawResult.get_double().get(numberResult)) { + ss << rawResult; + result.append(ss.str()); + return true; + } + return false; + } + default: + VELOX_UNREACHABLE(); + } + } + case simdjson::ondemand::json_type::boolean: { + bool boolResult; + if (!rawResult.get_bool().get(boolResult)) { + result.append(boolResult ? "true" : "false"); + return true; + } + return false; + } + case simdjson::ondemand::json_type::string: { + std::string_view stringResult; + if (!rawResult.get_string().get(stringResult)) { + result.append(stringResult); + return true; + } + return false; + } + case simdjson::ondemand::json_type::object: { + // For nested case, e.g., for "{"my": {"hello": 10}}", "$.my" will + // return an object type. 
+ ss << rawResult; + result.append(ss.str()); + return true; + } + case simdjson::ondemand::json_type::array: { + ss << rawResult; + result.append(ss.str()); + return true; + } + default: + return false; + } + } + + // Checks whether the obtained result is followed by valid char. Because + // On-Demand API we are using ignores json format validation for characters + // following the current parsing position. As json doc is padded with NULL + // characters, it's safe to do recursively check. + bool isValidEndingCharacter(const char* currentPos) { + char endingChar = *currentPos; + if (endingChar == ',' || endingChar == '}' || endingChar == ']') { + return true; + } + // These chars can be prior to a valid ending char. See reference: + // https://github.com/simdjson/simdjson/blob/v3.9.0/dependencies/jsoncppdist/jsoncpp.cpp + if (endingChar == ' ' || endingChar == '\r' || endingChar == '\n' || + endingChar == '\t') { + return isValidEndingCharacter(++currentPos); + } + return false; + } + + // Used for constant json path. + std::optional jsonPath_; +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterJson.cpp b/velox/functions/sparksql/registration/RegisterJson.cpp index e98052563f8e..7f41807ceb94 100644 --- a/velox/functions/sparksql/registration/RegisterJson.cpp +++ b/velox/functions/sparksql/registration/RegisterJson.cpp @@ -14,11 +14,14 @@ * limitations under the License. 
*/ #include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/GetJsonObject.h" #include "velox/functions/sparksql/JsonObjectKeys.h" namespace facebook::velox::functions::sparksql { void registerJsonFunctions(const std::string& prefix) { + registerFunction( + {prefix + "get_json_object"}); registerFunction, Varchar>( {prefix + "json_object_keys"}); } diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 24772505d82d..39087bd8adb5 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -33,6 +33,7 @@ add_executable( DecimalRoundTest.cpp DecimalUtilTest.cpp ElementAtTest.cpp + GetJsonObjectTest.cpp HashTest.cpp InTest.cpp JsonObjectKeysTest.cpp diff --git a/velox/functions/sparksql/tests/GetJsonObjectTest.cpp b/velox/functions/sparksql/tests/GetJsonObjectTest.cpp new file mode 100644 index 000000000000..3c370f531864 --- /dev/null +++ b/velox/functions/sparksql/tests/GetJsonObjectTest.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class GetJsonObjectTest : public SparkFunctionBaseTest { + protected: + std::optional getJsonObject( + const std::string& json, + const std::string& jsonPath) { + return evaluateOnce( + "get_json_object(c0, c1)", + std::optional(json), + std::optional(jsonPath)); + } +}; + +TEST_F(GetJsonObjectTest, basic) { + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 3.5})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 292222730})", "$.hello"), "292222730"); + EXPECT_EQ(getJsonObject(R"({"hello": -292222730})", "$.hello"), "-292222730"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": 3.5}})", "$.my.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": true}})", "$.my.hello"), "true"); + EXPECT_EQ(getJsonObject(R"({"hello": ""})", "$.hello"), ""); + EXPECT_EQ( + "0.0215434648799772", + getJsonObject(R"({"score":0.0215434648799772})", "$.score")); + // Returns input json if json path is "$". + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$"), + R"({"name": "Alice", "age": 5, "id": "001"})"); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.age"), + "5"); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.id"), + "001"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0]['my']['info']['age']"), + "5"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0].my.info.age"), + "5"); + + // Json object as result. 
+ EXPECT_EQ( + getJsonObject( + R"({"my": {"info": {"name": "Alice", "age": "5", "id": "001"}}})", + "$.my.info"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + EXPECT_EQ( + getJsonObject( + R"({"my": {"info": {"name": "Alice", "age": "5", "id": "001"}}})", + "$['my']['info']"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + + // Array as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other"), + R"(["v1", "v2"])"); + // Array element as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[0]"), + "v1"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[1]"), + "v2"); +} + +TEST_F(GetJsonObjectTest, nullResult) { + // Field not found. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hi"), std::nullopt); + + // Illegal json. + EXPECT_EQ(getJsonObject(R"({"hello"-3.5})", "$.hello"), std::nullopt); + EXPECT_EQ(getJsonObject(R"({"a": bad, "b": string})", "$.a"), std::nullopt); + + // Illegal json path. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$hello"), std::nullopt); + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$."), std::nullopt); + // The first char is not '$'. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", ".hello"), std::nullopt); + // Contains '$' not in the first position. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.$hello"), std::nullopt); + + // Invalid ending character. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"info": {"name": "Alice"quoted""}}}, {"other": ["v1", "v2"]}])", + "$[0].my.info.name"), + std::nullopt); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test