Skip to content

Commit

Permalink
Add cosine_similarity Presto function (#7374)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #7374

- https://en.wikipedia.org/wiki/Cosine_similarity
- [Presto Java implementation](https://github.com/prestodb/presto/blob/66aab8b7ef886fcebf38fae2d2e9094f12fe3f6b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java#L1627)

Reviewed By: laithsakka, mbasmanova

Differential Revision: D50589652

fbshipit-source-id: e0ddbdb47de65eb362c877d5d9e24256bdc7a71c
  • Loading branch information
amitkdutta authored and facebook-github-bot committed Nov 8, 2023
1 parent aec77a5 commit fb34b0f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 0 deletions.
11 changes: 11 additions & 0 deletions velox/docs/functions/presto/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ Mathematical Functions
verified for performance reasons. Returns ``high`` for all values of ``x``
when ``low`` is greater than ``high``.

.. function:: cosine_similarity(map(varchar, double), map(varchar, double)) -> double

Returns the `cosine similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_ between the vectors represented as map(varchar, double).
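If any input map is empty or all of its values are zero (zero norm), the function returns NaN. If any input map contains a NULL value, the function returns NULL.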

SELECT cosine_similarity(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0])); -- 1.0

SELECT cosine_similarity(MAP(ARRAY['a', 'b'], ARRAY[1.0, 2.0]), MAP(ARRAY['a', 'b'], ARRAY[NULL, 3.0])); -- NULL

SELECT cosine_similarity(MAP(ARRAY[], ARRAY[]), MAP(ARRAY['a', 'b'], ARRAY[2, 3])); -- NaN

.. function:: degrees(x) -> double

Converts angle x in radians to degrees.
Expand Down
52 changes: 52 additions & 0 deletions velox/functions/prestosql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "folly/CPortability.h"
#include "velox/common/base/Exceptions.h"
#include "velox/functions/Macros.h"
#include "velox/functions/Udf.h"
#include "velox/functions/prestosql/ArithmeticImpl.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -531,5 +532,56 @@ struct WilsonIntervalLowerFunction {
}
};

/// Implements cosine_similarity(map(varchar, double), map(varchar, double)).
///
/// Each map is treated as a sparse vector indexed by its string keys. The
/// result is dot(left, right) / (norm(left) * norm(right)), or NaN when either
/// map is empty or has a zero L2 norm.
template <typename T>
struct CosineSimilarityFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// Returns the L2 (Euclidean) norm of the map's values, i.e. the square
  /// root of the sum of squared values. Keys do not affect the result.
  double normalizeMap(const null_free_arg_type<Map<Varchar, double>>& map) {
    double sumOfSquares = 0.0;
    for (const auto& [name, weight] : map) {
      sumOfSquares += (weight * weight);
    }
    return std::sqrt(sumOfSquares);
  }

  /// Returns the dot product of the two maps: the sum of
  /// leftValue * rightValue over every key of 'leftMap' that is also found in
  /// 'rightMap'. Keys present in only one map contribute nothing.
  double mapDotProduct(
      const null_free_arg_type<Map<Varchar, double>>& leftMap,
      const null_free_arg_type<Map<Varchar, double>>& rightMap) {
    double sum = 0.0;
    for (const auto& [name, leftWeight] : leftMap) {
      if (auto match = rightMap.find(name); match != rightMap.end()) {
        sum += leftWeight * match->second;
      }
    }
    return sum;
  }

  /// Writes the cosine similarity of 'leftMap' and 'rightMap' to 'result'.
  /// 'result' is NaN when either map is empty or has a zero norm.
  void callNullFree(
      out_type<double>& result,
      const null_free_arg_type<Map<Varchar, double>>& leftMap,
      const null_free_arg_type<Map<Varchar, double>>& rightMap) {
    // Start from NaN; it is replaced only when the similarity is well defined.
    result = std::numeric_limits<double>::quiet_NaN();

    if (leftMap.empty() || rightMap.empty()) {
      return;
    }

    // Bail out before touching the right map if the left norm is already zero.
    const double leftNorm = normalizeMap(leftMap);
    if (leftNorm == 0.0) {
      return;
    }

    const double rightNorm = normalizeMap(rightMap);
    if (rightNorm == 0.0) {
      return;
    }

    const double dotProduct = mapDotProduct(leftMap, rightMap);
    result = dotProduct / (leftNorm * rightNorm);
  }
};

} // namespace
} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ void registerSimpleFunctions(const std::string& prefix) {
int64_t,
int64_t,
double>({prefix + "wilson_interval_lower"});
registerFunction<
CosineSimilarityFunction,
double,
Map<Varchar, double>,
Map<Varchar, double>>({prefix + "cosine_similarity"});
}

} // namespace
Expand Down
46 changes: 46 additions & 0 deletions velox/functions/prestosql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,5 +821,51 @@ TEST_F(ArithmeticTest, wilsonIntervalUpper) {
EXPECT_DOUBLE_EQ(wilsonIntervalUpper(1, 3, kInf).value(), 1.0);
}

// Verifies cosine_similarity over map(varchar, double) inputs: regular values,
// disjoint keys, NaN for empty / zero-norm maps, and NULL for maps containing
// a NULL value.
TEST_F(ArithmeticTest, cosineSimilarity) {
  // Evaluates cosine_similarity on a single row built from the two maps and
  // returns the (non-null) result.
  const auto cosineSimilarity =
      [&](const std::vector<std::pair<std::string, std::optional<double>>>&
              left,
          const std::vector<std::pair<std::string, std::optional<double>>>&
              right) {
        auto leftMap = makeMapVector<std::string, double>({left});
        auto rightMap = makeMapVector<std::string, double>({right});
        return evaluateOnce<double>(
                   "cosine_similarity(c0,c1)",
                   makeRowVector({leftMap, rightMap}))
            .value();
      };

  // Only key "b" is shared between the two maps.
  EXPECT_DOUBLE_EQ(
      (2.0 * 3.0) / (std::sqrt(5.0) * std::sqrt(10.0)),
      cosineSimilarity({{"a", 1}, {"b", 2}}, {{"c", 1}, {"b", 3}}));

  // Keys "b" and "c" are shared; "a" contributes to the left norm only.
  EXPECT_DOUBLE_EQ(
      (2.0 * 3.0 + (-1) * 1) / (std::sqrt(1 + 4 + 1) * std::sqrt(1 + 9)),
      cosineSimilarity({{"a", 1}, {"b", 2}, {"c", -1}}, {{"c", 1}, {"b", 3}}));

  // The function is symmetric: swapping the arguments gives the same result.
  EXPECT_DOUBLE_EQ(
      (2.0 * 3.0 + (-1) * 1) / (std::sqrt(1 + 4 + 1) * std::sqrt(1 + 9)),
      cosineSimilarity({{"c", 1}, {"b", 3}}, {{"a", 1}, {"b", 2}, {"c", -1}}));

  // No keys in common: the dot product is zero.
  EXPECT_DOUBLE_EQ(
      0.0,
      cosineSimilarity({{"a", 1}, {"b", 2}, {"c", -1}}, {{"d", 1}, {"e", 3}}));

  // Empty maps on either side produce NaN.
  EXPECT_TRUE(std::isnan(cosineSimilarity({}, {})));
  EXPECT_TRUE(std::isnan(cosineSimilarity({{"d", 1}, {"e", 3}}, {})));
  EXPECT_TRUE(std::isnan(cosineSimilarity({}, {{"d", 1}, {"e", 3}})));

  // A zero norm on either side produces NaN.
  EXPECT_TRUE(
      std::isnan(cosineSimilarity({{"a", 1}, {"b", 3}}, {{"a", 0}, {"b", 0}})));
  EXPECT_TRUE(
      std::isnan(cosineSimilarity({{"a", 0}, {"b", 0}}, {{"a", 1}, {"b", 3}})));

  // A NULL value inside an input map makes the result NULL.
  auto nullableLeftMap = makeNullableMapVector<StringView, double>(
      {{{{"a"_sv, 1}, {"b"_sv, std::nullopt}}}});
  auto rightMap =
      makeMapVector<StringView, double>({{{{"c"_sv, 1}, {"b"_sv, 3}}}});

  EXPECT_FALSE(evaluateOnce<double>(
                   "cosine_similarity(c0,c1)",
                   makeRowVector({nullableLeftMap, rightMap}))
                   .has_value());
}

} // namespace
} // namespace facebook::velox

0 comments on commit fb34b0f

Please sign in to comment.