Skip to content

Commit

Permalink
Back out "Refactor greatest and least Presto functions using simple f…
Browse files Browse the repository at this point in the history
…unction API" (facebookincubator#9493)

Summary:
Temporarily reverting as the switch to using a simple function implementation for 'greatest' and 'least' functions are causing issues registering the UDF for some internal use-cases.

Pull Request resolved: facebookincubator#9493

Original commit changeset: c389bad91197

Original Phabricator Diff: D55793910n

Reviewed By: wqfish, bikramSingh91

Differential Revision: D56160832

fbshipit-source-id: f7550b819f8b8f276b88cb33c52de05807a4f2d2
  • Loading branch information
s4ayub authored and facebook-github-bot committed Apr 16, 2024
1 parent 62fb397 commit fd5643a
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 175 deletions.
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_library(
FindFirst.cpp
FromUnixTime.cpp
FromUtf8.cpp
GreatestLeast.cpp
InPredicate.cpp
JsonFunctions.cpp
Map.cpp
Expand Down
207 changes: 207 additions & 0 deletions velox/functions/prestosql/GreatestLeast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cmath>
#include <type_traits>
#include "velox/common/base/Exceptions.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/type/Type.h"

namespace facebook::velox::functions {

namespace {

template <bool>
class ExtremeValueFunction;

using LeastFunction = ExtremeValueFunction<true>;
using GreatestFunction = ExtremeValueFunction<false>;

/**
* This class implements two functions:
*
* greatest(value1, value2, ..., valueN) → [same as input]
* Returns the largest of the provided values.
*
* least(value1, value2, ..., valueN) → [same as input]
* Returns the smallest of the provided values.
**/
template <bool isLeast>
class ExtremeValueFunction : public exec::VectorFunction {
private:
template <typename T>
bool shouldOverride(const T& currentValue, const T& candidateValue) const {
return isLeast ? candidateValue < currentValue
: candidateValue > currentValue;
}

// For double, presto should throw error if input is Nan
template <typename T>
void checkNan(const T& value) const {
if constexpr (std::is_same_v<T, TypeTraits<TypeKind::DOUBLE>::NativeType>) {
if (std::isnan(value)) {
VELOX_USER_FAIL(
"Invalid argument to {}: NaN", isLeast ? "least()" : "greatest()");
}
}
}

template <typename T>
void applyTyped(
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
context.ensureWritable(rows, outputType, result);
result->clearNulls(rows);

auto* flatResult = result->as<FlatVector<T>>();
BufferPtr resultValues = flatResult->mutableValues(rows.end());
T* __restrict rawResult = resultValues->asMutable<T>();

exec::DecodedArgs decodedArgs(rows, args, context);

std::set<size_t> usedInputs;
context.applyToSelectedNoThrow(rows, [&](int row) {
size_t valueIndex = 0;

T currentValue = decodedArgs.at(0)->valueAt<T>(row);
checkNan(currentValue);

for (auto i = 1; i < args.size(); ++i) {
auto candidateValue = decodedArgs.at(i)->template valueAt<T>(row);
checkNan(candidateValue);

if constexpr (isLeast) {
if (candidateValue < currentValue) {
currentValue = candidateValue;
valueIndex = i;
}
} else {
if (candidateValue > currentValue) {
currentValue = candidateValue;
valueIndex = i;
}
}
}
usedInputs.insert(valueIndex);

if constexpr (std::is_same_v<bool, T>) {
flatResult->set(row, currentValue);
} else {
rawResult[row] = currentValue;
}
});

if constexpr (std::is_same_v<T, StringView>) {
for (auto index : usedInputs) {
flatResult->acquireSharedStringBuffers(args[index].get());
}
}
}

public:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
switch (outputType.get()->kind()) {
case TypeKind::BOOLEAN:
applyTyped<bool>(rows, args, outputType, context, result);
return;
case TypeKind::TINYINT:
applyTyped<int8_t>(rows, args, outputType, context, result);
return;
case TypeKind::SMALLINT:
applyTyped<int16_t>(rows, args, outputType, context, result);
return;
case TypeKind::INTEGER:
applyTyped<int32_t>(rows, args, outputType, context, result);
return;
case TypeKind::BIGINT:
applyTyped<int64_t>(rows, args, outputType, context, result);
return;
case TypeKind::HUGEINT:
applyTyped<int128_t>(rows, args, outputType, context, result);
return;
case TypeKind::REAL:
applyTyped<float>(rows, args, outputType, context, result);
return;
case TypeKind::DOUBLE:
applyTyped<double>(rows, args, outputType, context, result);
return;
case TypeKind::VARCHAR:
applyTyped<StringView>(rows, args, outputType, context, result);
return;
case TypeKind::TIMESTAMP:
applyTyped<Timestamp>(rows, args, outputType, context, result);
return;
default:
VELOX_FAIL(
"Unsupported input type for {}: {}",
isLeast ? "least()" : "greatest()",
outputType->toString());
}
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
const std::vector<std::string> types = {
"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"double",
"real",
"varchar",
"timestamp",
"date",
};
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
for (const auto& type : types) {
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType(type)
.argumentType(type)
.variableArity()
.build());
}
signatures.emplace_back(exec::FunctionSignatureBuilder()
.integerVariable("precision")
.integerVariable("scale")
.returnType("DECIMAL(precision, scale)")
.argumentType("DECIMAL(precision, scale)")
.variableArity()
.build());
return signatures;
}
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_least,
LeastFunction::signatures(),
std::make_unique<LeastFunction>());

VELOX_DECLARE_VECTOR_FUNCTION(
udf_greatest,
GreatestFunction::signatures(),
std::make_unique<GreatestFunction>());

} // namespace facebook::velox::functions
101 changes: 0 additions & 101 deletions velox/functions/prestosql/GreatestLeast.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,9 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/IsNull.h"
#include "velox/functions/prestosql/Cardinality.h"
#include "velox/functions/prestosql/GreatestLeast.h"
#include "velox/functions/prestosql/InPredicate.h"

namespace facebook::velox::functions {

template <typename T>
inline void registerGreatestLeastFunction(const std::string& prefix) {
registerFunction<ParameterBinder<GreatestFunction, T>, T, T, Variadic<T>>(
{prefix + "greatest"});

registerFunction<ParameterBinder<LeastFunction, T>, T, T, Variadic<T>>(
{prefix + "least"});
}

inline void registerAllGreatestLeastFunctions(const std::string& prefix) {
registerGreatestLeastFunction<bool>(prefix);
registerGreatestLeastFunction<int8_t>(prefix);
registerGreatestLeastFunction<int16_t>(prefix);
registerGreatestLeastFunction<int32_t>(prefix);
registerGreatestLeastFunction<int64_t>(prefix);
registerGreatestLeastFunction<float>(prefix);
registerGreatestLeastFunction<double>(prefix);
registerGreatestLeastFunction<Varchar>(prefix);
registerGreatestLeastFunction<LongDecimal<P1, S1>>(prefix);
registerGreatestLeastFunction<ShortDecimal<P1, S1>>(prefix);
registerGreatestLeastFunction<Date>(prefix);
registerGreatestLeastFunction<Timestamp>(prefix);
}

extern void registerSubscriptFunction(
const std::string& name,
bool enableCaching);
Expand Down Expand Up @@ -73,9 +47,11 @@ void registerGeneralFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform");
VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "reduce");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter");
VELOX_REGISTER_VECTOR_FUNCTION(udf_typeof, prefix + "typeof");

registerAllGreatestLeastFunctions(prefix);
VELOX_REGISTER_VECTOR_FUNCTION(udf_least, prefix + "least");
VELOX_REGISTER_VECTOR_FUNCTION(udf_greatest, prefix + "greatest");

VELOX_REGISTER_VECTOR_FUNCTION(udf_typeof, prefix + "typeof");

registerFunction<CardinalityFunction, int64_t, Array<Any>>(
{prefix + "cardinality"});
Expand Down
Loading

0 comments on commit fd5643a

Please sign in to comment.