diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index 9077da198b56..5e599286b486 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -27,6 +27,17 @@ Mathematical Functions verified for performance reasons. Returns ``high`` for all values of ``x`` when ``low`` is greater than ``high``. +.. function:: cosine_similarity(map(varchar, double), map(varchar, double)) -> double + + Returns the `cosine similarity `_ between the vectors represented as map(varchar, double). + If any input map is empty, the function returns NaN. + + SELECT cosine_similarity(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0])); -- 1.0 + + SELECT cosine_similarity(MAP(ARRAY['a', 'b'], ARRAY[1.0, 2.0]), MAP(ARRAY['a', 'b'], ARRAY[NULL, 3.0])); -- NULL + + SELECT cosine_similarity(MAP(ARRAY[], ARRAY[]), MAP(ARRAY['a', 'b'], ARRAY[2, 3])); -- NaN + .. function:: degrees(x) -> double Converts angle x in radians to degrees. diff --git a/velox/functions/prestosql/Arithmetic.h b/velox/functions/prestosql/Arithmetic.h index 552d8a9cf26e..f5b69d4007b4 100644 --- a/velox/functions/prestosql/Arithmetic.h +++ b/velox/functions/prestosql/Arithmetic.h @@ -28,6 +28,7 @@ #include "folly/CPortability.h" #include "velox/common/base/Exceptions.h" #include "velox/functions/Macros.h" +#include "velox/functions/Udf.h" #include "velox/functions/prestosql/ArithmeticImpl.h" namespace facebook::velox::functions { @@ -531,5 +532,56 @@ struct WilsonIntervalLowerFunction { } }; +template +struct CosineSimilarityFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + double normalizeMap(const null_free_arg_type>& map) { + double norm = 0.0; + for (const auto& [key, value] : map) { + norm += (value * value); + } + return std::sqrt(norm); + } + + double mapDotProduct( + const null_free_arg_type>& leftMap, + const null_free_arg_type>& rightMap) { + double result = 0.0; + for (const auto& [key, value] : leftMap) { + auto it = rightMap.find(key); + if (it != rightMap.end()) { + result += value * it->second; + } + } + return result; + } + + void callNullFree( + out_type& result, + const null_free_arg_type>& leftMap, + const null_free_arg_type>& rightMap) { + if (leftMap.empty() || rightMap.empty()) { + result = std::numeric_limits::quiet_NaN(); + return; + } + + double normLeftMap = normalizeMap(leftMap); + if (normLeftMap == 0.0) { + result = std::numeric_limits::quiet_NaN(); + return; + } + + double normRightMap = normalizeMap(rightMap); + if (normRightMap == 0.0) { + result = std::numeric_limits::quiet_NaN(); + return; + } + + double dotProduct = mapDotProduct(leftMap, rightMap); + result = dotProduct / (normLeftMap * normRightMap); + } +}; + } // namespace } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp index 1812065700c9..697ff740de0f 100644 --- a/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArithmeticFunctionsRegistration.cpp @@ -137,6 +137,11 @@ void registerSimpleFunctions(const std::string& prefix) { int64_t, int64_t, double>({prefix + "wilson_interval_lower"}); + registerFunction< + CosineSimilarityFunction, + double, + Map, + Map>({prefix + "cosine_similarity"}); } } // namespace diff --git a/velox/functions/prestosql/tests/ArithmeticTest.cpp b/velox/functions/prestosql/tests/ArithmeticTest.cpp index aa4ef98a63fd..9b35bdcefd81 100644 --- a/velox/functions/prestosql/tests/ArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/ArithmeticTest.cpp @@ -821,5 +821,51 @@ TEST_F(ArithmeticTest, wilsonIntervalUpper) { EXPECT_DOUBLE_EQ(wilsonIntervalUpper(1, 3, kInf).value(), 1.0); } +TEST_F(ArithmeticTest, cosineSimilarity) { + const auto cosineSimilarity = + [&](const std::vector>>& + left, + const std::vector>>& + right) { + auto leftMap = makeMapVector({left}); + auto rightMap = makeMapVector({right}); + return evaluateOnce( + "cosine_similarity(c0,c1)", + makeRowVector({leftMap, rightMap})) + .value(); + }; + + EXPECT_DOUBLE_EQ( + (2.0 * 3.0) / (std::sqrt(5.0) * std::sqrt(10.0)), + cosineSimilarity({{"a", 1}, {"b", 2}}, {{"c", 1}, {"b", 3}})); + + EXPECT_DOUBLE_EQ( + (2.0 * 3.0 + (-1) * 1) / (std::sqrt(1 + 4 + 1) * std::sqrt(1 + 9)), + cosineSimilarity({{"a", 1}, {"b", 2}, {"c", -1}}, {{"c", 1}, {"b", 3}})); + + EXPECT_DOUBLE_EQ( + (2.0 * 3.0 + (-1) * 1) / (std::sqrt(1 + 4 + 1) * std::sqrt(1 + 9)), + cosineSimilarity({{"a", 1}, {"b", 2}, {"c", -1}}, {{"c", 1}, {"b", 3}})); + + EXPECT_DOUBLE_EQ( + 0.0, + cosineSimilarity({{"a", 1}, {"b", 2}, {"c", -1}}, {{"d", 1}, {"e", 3}})); + + EXPECT_TRUE(std::isnan(cosineSimilarity({}, {}))); + EXPECT_TRUE(std::isnan(cosineSimilarity({{"d", 1}, {"e", 3}}, {}))); + EXPECT_TRUE( + std::isnan(cosineSimilarity({{"a", 1}, {"b", 3}}, {{"a", 0}, {"b", 0}}))); + + auto nullableLeftMap = makeNullableMapVector( + {{{{"a"_sv, 1}, {"b"_sv, std::nullopt}}}}); + auto rightMap = + makeMapVector({{{{"c"_sv, 1}, {"b"_sv, 3}}}}); + + EXPECT_FALSE(evaluateOnce( + "cosine_similarity(c0,c1)", + makeRowVector({nullableLeftMap, rightMap})) + .has_value()); +} + } // namespace } // namespace facebook::velox