From fd5643aad261d809021b701555b0bfd5dd2870d0 Mon Sep 17 00:00:00 2001 From: Shabab Ayub Date: Tue, 16 Apr 2024 12:31:21 -0700 Subject: [PATCH] Back out "Refactor greatest and least Presto functions using simple function API" (#9493) Summary: Temporarily reverting as the switch to using a simple function implementation for 'greatest' and 'least' functions are causing issues registering the UDF for some internal use-cases. Pull Request resolved: https://github.com/facebookincubator/velox/pull/9493 Original commit changeset: c389bad91197 Original Phabricator Diff: D55793910n Reviewed By: wqfish, bikramSingh91 Differential Revision: D56160832 fbshipit-source-id: f7550b819f8b8f276b88cb33c52de05807a4f2d2 --- velox/functions/prestosql/CMakeLists.txt | 1 + velox/functions/prestosql/GreatestLeast.cpp | 207 ++++++++++++++++++ velox/functions/prestosql/GreatestLeast.h | 101 --------- .../GeneralFunctionsRegistration.cpp | 32 +-- .../prestosql/tests/GreatestLeastTest.cpp | 69 ++---- 5 files changed, 235 insertions(+), 175 deletions(-) create mode 100644 velox/functions/prestosql/GreatestLeast.cpp delete mode 100644 velox/functions/prestosql/GreatestLeast.h diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 54d959cbbb37..3a8008be601b 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -34,6 +34,7 @@ add_library( FindFirst.cpp FromUnixTime.cpp FromUtf8.cpp + GreatestLeast.cpp InPredicate.cpp JsonFunctions.cpp Map.cpp diff --git a/velox/functions/prestosql/GreatestLeast.cpp b/velox/functions/prestosql/GreatestLeast.cpp new file mode 100644 index 000000000000..afc085d4a7ea --- /dev/null +++ b/velox/functions/prestosql/GreatestLeast.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions { + +namespace { + +template +class ExtremeValueFunction; + +using LeastFunction = ExtremeValueFunction; +using GreatestFunction = ExtremeValueFunction; + +/** + * This class implements two functions: + * + * greatest(value1, value2, ..., valueN) → [same as input] + * Returns the largest of the provided values. + * + * least(value1, value2, ..., valueN) → [same as input] + * Returns the smallest of the provided values. + **/ +template +class ExtremeValueFunction : public exec::VectorFunction { + private: + template + bool shouldOverride(const T& currentValue, const T& candidateValue) const { + return isLeast ? candidateValue < currentValue + : candidateValue > currentValue; + } + + // For double, presto should throw error if input is Nan + template + void checkNan(const T& value) const { + if constexpr (std::is_same_v::NativeType>) { + if (std::isnan(value)) { + VELOX_USER_FAIL( + "Invalid argument to {}: NaN", isLeast ? "least()" : "greatest()"); + } + } + } + + template + void applyTyped( + const SelectivityVector& rows, + const std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { + context.ensureWritable(rows, outputType, result); + result->clearNulls(rows); + + auto* flatResult = result->as>(); + BufferPtr resultValues = flatResult->mutableValues(rows.end()); + T* __restrict rawResult = resultValues->asMutable(); + + exec::DecodedArgs decodedArgs(rows, args, context); + + std::set usedInputs; + context.applyToSelectedNoThrow(rows, [&](int row) { + size_t valueIndex = 0; + + T currentValue = decodedArgs.at(0)->valueAt(row); + checkNan(currentValue); + + for (auto i = 1; i < args.size(); ++i) { + auto candidateValue = decodedArgs.at(i)->template valueAt(row); + checkNan(candidateValue); + + if constexpr (isLeast) { + if (candidateValue < currentValue) { + currentValue = candidateValue; + valueIndex = i; + } + } else { + if (candidateValue > currentValue) { + currentValue = candidateValue; + valueIndex = i; + } + } + } + usedInputs.insert(valueIndex); + + if constexpr (std::is_same_v) { + flatResult->set(row, currentValue); + } else { + rawResult[row] = currentValue; + } + }); + + if constexpr (std::is_same_v) { + for (auto index : usedInputs) { + flatResult->acquireSharedStringBuffers(args[index].get()); + } + } + } + + public: + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + switch (outputType.get()->kind()) { + case TypeKind::BOOLEAN: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::TINYINT: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::SMALLINT: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::INTEGER: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::BIGINT: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::HUGEINT: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::REAL: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::DOUBLE: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::VARCHAR: + applyTyped(rows, args, outputType, context, result); + return; + case TypeKind::TIMESTAMP: + applyTyped(rows, args, outputType, context, result); + return; + default: + VELOX_FAIL( + "Unsupported input type for {}: {}", + isLeast ? "least()" : "greatest()", + outputType->toString()); + } + } + + static std::vector> signatures() { + const std::vector types = { + "boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "double", + "real", + "varchar", + "timestamp", + "date", + }; + std::vector> signatures; + for (const auto& type : types) { + signatures.emplace_back(exec::FunctionSignatureBuilder() + .returnType(type) + .argumentType(type) + .variableArity() + .build()); + } + signatures.emplace_back(exec::FunctionSignatureBuilder() + .integerVariable("precision") + .integerVariable("scale") + .returnType("DECIMAL(precision, scale)") + .argumentType("DECIMAL(precision, scale)") + .variableArity() + .build()); + return signatures; + } +}; +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_least, + LeastFunction::signatures(), + std::make_unique()); + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_greatest, + GreatestFunction::signatures(), + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/GreatestLeast.h b/velox/functions/prestosql/GreatestLeast.h deleted file mode 100644 index a648aa5611d7..000000000000 --- a/velox/functions/prestosql/GreatestLeast.h +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include "velox/functions/Macros.h" - -namespace facebook::velox::functions { -namespace details { -/** - * This class implements two functions: - * - * greatest(value1, value2, ..., valueN) → [same as input] - * Returns the largest of the provided values. - * - * least(value1, value2, ..., valueN) → [same as input] - * Returns the smallest of the provided values. - * - * For DOUBLE and REAL type, NaN is considered as the biggest according to - * https://github.com/prestodb/presto/issues/22391 - **/ -template -struct ExtremeValueFunction { - VELOX_DEFINE_FUNCTION_TYPES(TExec); - - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& firstElement, - const arg_type>& remainingElement) { - auto currentValue = firstElement; - - for (auto element : remainingElement) { - auto candidateValue = element.value(); - - if constexpr (isLeast) { - if (smallerThan(candidateValue, currentValue)) { - currentValue = candidateValue; - } - } else { - if (greaterThan(candidateValue, currentValue)) { - currentValue = candidateValue; - } - } - } - - result = currentValue; - } - - private: - template - bool greaterThan(const K& lhs, const K& rhs) const { - if constexpr (std::is_same_v || std::is_same_v) { - if (std::isnan(lhs)) { - return true; - } - - if (std::isnan(rhs)) { - return false; - } - } - - return lhs > rhs; - } - - template - bool smallerThan(const K& lhs, const K& rhs) const { - if constexpr (std::is_same_v || std::is_same_v) { - if (std::isnan(lhs)) { - return false; - } - - if (std::isnan(rhs)) { - return true; - } - } - - return lhs < rhs; - } -}; -} // namespace details - -template -using LeastFunction = details::ExtremeValueFunction; - -template -using GreatestFunction = details::ExtremeValueFunction; - -} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index 37acd5d68130..c5a10411eb86 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -17,35 +17,9 @@ #include "velox/functions/Registerer.h" #include "velox/functions/lib/IsNull.h" #include "velox/functions/prestosql/Cardinality.h" -#include "velox/functions/prestosql/GreatestLeast.h" #include "velox/functions/prestosql/InPredicate.h" namespace facebook::velox::functions { - -template -inline void registerGreatestLeastFunction(const std::string& prefix) { - registerFunction, T, T, Variadic>( - {prefix + "greatest"}); - - registerFunction, T, T, Variadic>( - {prefix + "least"}); -} - -inline void registerAllGreatestLeastFunctions(const std::string& prefix) { - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction>(prefix); - registerGreatestLeastFunction>(prefix); - registerGreatestLeastFunction(prefix); - registerGreatestLeastFunction(prefix); -} - extern void registerSubscriptFunction( const std::string& name, bool enableCaching); @@ -73,9 +47,11 @@ void registerGeneralFunctions(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform"); VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "reduce"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_typeof, prefix + "typeof"); - registerAllGreatestLeastFunctions(prefix); + VELOX_REGISTER_VECTOR_FUNCTION(udf_least, prefix + "least"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_greatest, prefix + "greatest"); + + VELOX_REGISTER_VECTOR_FUNCTION(udf_typeof, prefix + "typeof"); registerFunction>( {prefix + "cardinality"}); diff --git a/velox/functions/prestosql/tests/GreatestLeastTest.cpp b/velox/functions/prestosql/tests/GreatestLeastTest.cpp index 9bfddc552144..e19e13d61410 100644 --- a/velox/functions/prestosql/tests/GreatestLeastTest.cpp +++ b/velox/functions/prestosql/tests/GreatestLeastTest.cpp @@ -14,8 +14,6 @@ * limitations under the License. */ -#include -#include #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -85,50 +83,29 @@ TEST_F(GreatestLeastTest, leastReal) { {0, -100, -1.1}); } -TEST_F(GreatestLeastTest, greatestNanInput) { - auto constexpr kInf32 = std::numeric_limits::infinity(); - auto constexpr kInf64 = std::numeric_limits::infinity(); - - auto greatestFloat = [&](float a, float b, float c) { - return evaluateOnce( - "greatest(c0, c1, c2)", {a}, {b}, {c}) - .value(); - }; - - auto greatestDouble = [&](double a, double b, double c) { - return evaluateOnce( - "greatest(c0, c1, c2)", {a}, {b}, {c}) - .value(); - }; - - EXPECT_TRUE(std::isnan(greatestFloat(1.0, std::nanf("1"), 2.0))); - EXPECT_TRUE(std::isnan(greatestFloat(std::nanf("1"), 1.0, kInf32))); - - EXPECT_TRUE(std::isnan(greatestDouble(1.0, std::nan("1"), 2.0))); - EXPECT_TRUE(std::isnan(greatestDouble(std::nan("1"), 1.0, kInf64))); -} - -TEST_F(GreatestLeastTest, leastNanInput) { - auto constexpr kInf32 = std::numeric_limits::infinity(); - auto constexpr kInf64 = std::numeric_limits::infinity(); - - auto leastFloat = [&](float a, float b, float c) { - return evaluateOnce( - "least(c0, c1, c2)", {a}, {b}, {c}) - .value(); - }; - - auto leastDouble = [&](double a, double b, double c) { - return evaluateOnce( - "least(c0, c1, c2)", {a}, {b}, {c}) - .value(); - }; - - EXPECT_EQ(leastFloat(1.0, std::nanf("1"), 0.5), 0.5); - EXPECT_EQ(leastFloat(std::nanf("1"), 1.0, -kInf32), -kInf32); - - EXPECT_EQ(leastDouble(1.0, std::nan("1"), 0.5), 0.5); - EXPECT_EQ(leastDouble(std::nan("1"), 1.0, -kInf64), -kInf64); +TEST_F(GreatestLeastTest, nanInput) { + // Presto rejects NaN inputs of type DOUBLE, but allows NaN inputs of type + // REAL. + std::vector input{0, 1.1, std::nan("1")}; + VELOX_ASSERT_THROW( + runTest("least(c0)", {{0.0 / 0.0}}, {0}), + "Invalid argument to least(): NaN"); + runTest("try(least(c0, 1.0))", {input}, {0, 1.0, std::nullopt}); + + VELOX_ASSERT_THROW( + runTest("greatest(c0)", {1, {0.0 / 0.0}}, {1, 0}), + "Invalid argument to greatest(): NaN"); + runTest("try(greatest(c0, 1.0))", {input}, {1.0, 1.1, std::nullopt}); + + auto result = evaluateOnce( + "is_nan(least(c0))", std::nanf("1"), 1.2); + ASSERT_TRUE(result.has_value()); + ASSERT_TRUE(result.value()); + + result = evaluateOnce( + "is_nan(greatest(c0))", std::nanf("1"), 1.2); + ASSERT_TRUE(result.has_value()); + ASSERT_TRUE(result.value()); } TEST_F(GreatestLeastTest, greatestDouble) {