From a130005c0520fd174d02f54147cc27bdcfa0c6d8 Mon Sep 17 00:00:00 2001 From: Wills Feng Date: Mon, 15 Apr 2024 14:47:47 -0700 Subject: [PATCH] add inverse_weibull_cdf --- velox/docs/functions/presto/math.rst | 5 ++ velox/functions/prestosql/Probability.h | 20 ++++++++ .../ArithmeticFunctionsRegistration.cpp | 4 ++ .../prestosql/tests/ProbabilityTest.cpp | 47 +++++++++++++++++++ 4 files changed, 76 insertions(+) diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index 82a5a73594b31..002b40bdaf265 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -319,6 +319,11 @@ Probability Functions: inverse_cdf probability (p): P(N < n). The a, b parameters must be positive real values (all of type DOUBLE). The probability p must lie on the interval [0, 1]. +.. function:: inverse_weibull_cdf(a, b, p) -> double + + Compute the inverse of the Weibull cdf with given parameters ``a``, ``b`` for the probability ``p``. + The ``a``, ``b`` parameters must be positive double values. The probability ``p`` must be a double + on the interval [0, 1]. ==================================== Statistical Functions diff --git a/velox/functions/prestosql/Probability.h b/velox/functions/prestosql/Probability.h index d3517ee2afd48..68eb4c39ea9be 100644 --- a/velox/functions/prestosql/Probability.h +++ b/velox/functions/prestosql/Probability.h @@ -249,5 +249,25 @@ struct WeibullCDFFunction { } }; +template +struct InverseWeibullCDFFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(double& result, double a, double b, double p) { + static constexpr double kInf = std::numeric_limits::infinity(); + + VELOX_USER_CHECK((p >= 0) && (p <= 1), "p must be in the interval [0, 1]"); + VELOX_USER_CHECK_GT(a, 0, "a must be greater than 0"); + VELOX_USER_CHECK_GT(b, 0, "b must be greater than 0"); + + if (b == kInf) { + result = kInf; + } else { + // https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/org/apache/commons/math3/distribution/WeibullDistribution.html#inverseCumulativeProbability(double) + result = b * std::pow(-std::log1p(-p), 1.0 / a); + } + } +}; + } // namespace } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp index 3984366fc9ae7..f0d12a08a3ff9 100644 --- a/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp @@ -189,6 +189,10 @@ void registerSimpleFunctions(const std::string& prefix) { Map>({prefix + "cosine_similarity"}); registerFunction( {prefix + "weibull_cdf"}); + registerFunction( + {prefix + "inverse_beta_cdf"}); + registerFunction( + {prefix + "inverse_weibull_cdf"}); } } // namespace diff --git a/velox/functions/prestosql/tests/ProbabilityTest.cpp b/velox/functions/prestosql/tests/ProbabilityTest.cpp index 8ccbd82d12610..28082ef2a0e4e 100644 --- a/velox/functions/prestosql/tests/ProbabilityTest.cpp +++ b/velox/functions/prestosql/tests/ProbabilityTest.cpp @@ -453,5 +453,52 @@ TEST_F(ProbabilityTest, weibullCDF) { weibullCDF(kDoubleMin, kNan, kDoubleMax), "b must be greater than 0"); } +TEST_F(ProbabilityTest, inverseWeibullCDF) { + const auto inverseWeibullCDF = [&](std::optional a, + std::optional b, + std::optional p) { + return evaluateOnce("inverse_weibull_cdf(c0, c1, c2)", a, b, p); + }; + + EXPECT_EQ(inverseWeibullCDF(1.0, 1.0, 0.), 0.0); + EXPECT_EQ(inverseWeibullCDF(1.0, 1.0, 0.632), 0.9996723408132061); + EXPECT_EQ(inverseWeibullCDF(1.0, 0.6, 0.91), 1.4447673651911233); + + EXPECT_EQ(inverseWeibullCDF(std::nullopt, 1.0, 0.3), std::nullopt); + EXPECT_EQ(inverseWeibullCDF(1.0, std::nullopt, 0.2), std::nullopt); + EXPECT_EQ(inverseWeibullCDF(1.0, 0.4, std::nullopt), std::nullopt); + + EXPECT_EQ(inverseWeibullCDF(kDoubleMin, 1.0, 0.3), 0.0); + EXPECT_EQ(inverseWeibullCDF(kDoubleMax, 1.0, 0.4), 1.0); + EXPECT_EQ(inverseWeibullCDF(1.0, kDoubleMin, 0.5), 1.5423036715619055e-308); + EXPECT_EQ(inverseWeibullCDF(1.0, kDoubleMax, 0.7), kInf); + EXPECT_EQ(inverseWeibullCDF(kInf, 1.0, 0.1), 1.0); + EXPECT_EQ(inverseWeibullCDF(kDoubleMin, kDoubleMin, 0.9), kInf); + EXPECT_EQ( + inverseWeibullCDF(kDoubleMax, kDoubleMax, 0.8), 1.7976931348623157e+308); + EXPECT_EQ(inverseWeibullCDF(kInf, 999999999.9, 0.4), 9.999999999E8); + EXPECT_THAT(inverseWeibullCDF(1.0, kInf, 0.2), IsInf()); + + VELOX_ASSERT_THROW(inverseWeibullCDF(0, 3, 0.5), "a must be greater than 0"); + VELOX_ASSERT_THROW(inverseWeibullCDF(3, 0, 0.5), "b must be greater than 0"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(3, 5, -0.1), "p must be in the interval [0, 1]"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(3, 5, 1.1), "p must be in the interval [0, 1]"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(kNan, 1.0, 0.1), "a must be greater than 0"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(1.0, kNan, 0.4), "b must be greater than 0"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(1.0, 1.0, kNan), "p must be in the interval [0, 1]"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(999999999.9, 999999999.0, kInf), + "p must be in the interval [0, 1]"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(-1.0, 1.0, 0.1), "a must be greater than 0"); + VELOX_ASSERT_THROW( + inverseWeibullCDF(1.0, -1.0, 0.4), "b must be greater than 0"); +} + } // namespace } // namespace facebook::velox