Skip to content

Commit

Permalink
Add custom comparator to IPAddressType (facebookincubator#11347)
Browse files Browse the repository at this point in the history
Summary:

We need a comparator for IPAddressType.

- Pass providesCustomComparison = true to the constructor to support custom comparison
- Compare by converting the int128_t/BIGINT to a byte array and reversing the byte order; the result is then compared with memcmp
- Support between for IPAddressType
- Register the type to saber, since comparison now takes IPAddress

Differential Revision: D64880148
  • Loading branch information
yuandagits authored and facebook-github-bot committed Oct 30, 2024
1 parent b1834bd commit d4eb5f7
Show file tree
Hide file tree
Showing 8 changed files with 463 additions and 31 deletions.
5 changes: 3 additions & 2 deletions velox/functions/prestosql/GreatestLeast.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ struct ExtremeValueFunction {
return wrapper;
}

int64_t extractValue(
const exec::CustomTypeWithCustomComparisonView<int64_t>& wrapper) const {
template <typename U>
U extractValue(
const exec::CustomTypeWithCustomComparisonView<U>& wrapper) const {
return *wrapper;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/Type.h"

Expand All @@ -30,13 +31,15 @@ void registerNonSimdizableScalar(const std::vector<std::string>& aliases) {
registerFunction<T, TReturn, Timestamp, Timestamp>(aliases);
registerFunction<T, TReturn, TimestampWithTimezone, TimestampWithTimezone>(
aliases);
registerFunction<T, TReturn, IPAddress, IPAddress>(aliases);
}
} // namespace

void registerComparisonFunctions(const std::string& prefix) {
// Comparison functions also need TimestampWithTimezoneType,
// independent of DateTimeFunctions
registerTimestampWithTimeZoneType();
registerIPAddressType();

registerNonSimdizableScalar<EqFunction, bool>({prefix + "eq"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_eq, prefix + "eq");
Expand Down Expand Up @@ -116,6 +119,8 @@ void registerComparisonFunctions(const std::string& prefix) {
TimestampWithTimezone,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "between"});
registerFunction<BetweenFunction, bool, IPAddress, IPAddress, IPAddress>(
{prefix + "between"});
}

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/functions/prestosql/GreatestLeast.h"
#include "velox/functions/prestosql/InPredicate.h"
#include "velox/functions/prestosql/Reduce.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -57,6 +58,7 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) {
registerGreatestLeastFunction<Date>(prefix);
registerGreatestLeastFunction<Timestamp>(prefix);
registerGreatestLeastFunction<TimestampWithTimezone>(prefix);
registerGreatestLeastFunction<IPAddress>(prefix);
}
} // namespace

Expand Down
264 changes: 264 additions & 0 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/Udf.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/tests/utils/CustomTypesForTesting.h"
#include "velox/type/tz/TimeZoneMap.h"
Expand Down Expand Up @@ -1182,6 +1183,269 @@ TEST_F(ComparisonsTest, TimestampWithTimezone) {
false}));
}

// Exercises =, <>, <, >, <=, >=, IS DISTINCT FROM and BETWEEN on IPADDRESS
// columns. Each input row pairs two addresses (or a null) and every expected
// vector below lists the per-row outcome in the same row order.
TEST_F(ComparisonsTest, IpAddressType) {
  // Converts a textual address into the int128_t stored in an IPADDRESS
  // vector. value() is called unconditionally, so every literal used in this
  // test is expected to parse.
  auto toIp = [](const std::string& text) -> int128_t {
    return ipaddress::tryGetIPv6asInt128FromString(text).value();
  };

  // Evaluates a boolean expression over 'inputs' and checks the outcome
  // against 'expectedResult'.
  auto checkExpression = [&](const std::string& expr,
                             RowVectorPtr inputs,
                             VectorPtr expectedResult) {
    test::assertEqualVectors(
        expectedResult, evaluate<SimpleVector<bool>>(expr, inputs));
  };

  // Left-hand operands (column c0), 16 rows.
  auto left = makeNullableFlatVector<int128_t>(
      {
          // Rows 0-6: dotted-quad literals.
          toIp("1.1.1.1"),
          toIp("255.255.255.255"),
          toIp("1.2.3.4"),
          toIp("1.1.1.2"),
          toIp("1.1.2.1"),
          toIp("1.1.1.1"),
          toIp("1.1.1.1"),
          // Rows 7-13: colon-separated literals.
          toIp("::1"),
          toIp("2001:0db8:0000:0000:0000:ff00:0042:8329"),
          toIp("::ffff:1.2.3.4"),
          toIp("::ffff:0.1.1.1"),
          toIp("::FFFF:FFFF:FFFF"),
          toIp("::0001:255.255.255.255"),
          toIp("::ffff:ffff:ffff"),
          // Rows 14-15: a null on one side of the comparison.
          std::nullopt,
          toIp("::0001:255.255.255.255"),
      },
      IPADDRESS());

  // Right-hand operands (column c1), 16 rows aligned with 'left'.
  auto right = makeNullableFlatVector<int128_t>(
      {
          // Rows 0-6.
          toIp("1.1.1.1"),
          toIp("255.255.255.255"),
          toIp("1.1.1.1"),
          toIp("1.1.1.1"),
          toIp("1.1.1.2"),
          toIp("1.1.2.1"),
          toIp("255.1.1.1"),
          // Rows 7-13.
          toIp("::1"),
          toIp("2001:db8::ff00:42:8329"),
          toIp("1.2.3.4"),
          toIp("::ffff:1.1.1.0"),
          toIp("::0001:255.255.255.255"),
          toIp("255.255.255.255"),
          toIp("255.255.255.255"),
          // Rows 14-15.
          toIp("255.255.255.255"),
          std::nullopt,
      },
      IPADDRESS());

  auto pairs = makeRowVector({left, right});

  // Equality. Rows 8, 9 and 13 expect differently spelled literals to compare
  // equal. Rows with a null operand yield null.
  checkExpression(
      "c0 = c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           true, true, false, false, false, false, false,
           // Rows 7-13.
           true, true, true, false, false, false, true,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // Inequality: the negation of the equality expectations, nulls preserved.
  checkExpression(
      "c0 <> c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           false, false, true, true, true, true, true,
           // Rows 7-13.
           false, false, false, true, true, true, false,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // Strictly less than.
  checkExpression(
      "c0 < c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           false, false, false, false, false, true, true,
           // Rows 7-13.
           false, false, false, true, false, true, false,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // Strictly greater than.
  checkExpression(
      "c0 > c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           false, false, true, true, true, false, false,
           // Rows 7-13.
           false, false, false, false, true, false, false,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // Less than or equal.
  checkExpression(
      "c0 <= c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           true, true, false, false, false, true, true,
           // Rows 7-13.
           true, true, true, true, false, true, true,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // Greater than or equal.
  checkExpression(
      "c0 >= c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           true, true, true, true, true, false, false,
           // Rows 7-13.
           true, true, true, false, true, false, true,
           // Rows 14-15.
           std::nullopt, std::nullopt}));

  // IS DISTINCT FROM: matches <> on non-null rows, but the rows with exactly
  // one null operand are expected to produce true rather than null.
  checkExpression(
      "c0 is distinct from c1",
      pairs,
      makeNullableFlatVector<bool>(
          {// Rows 0-6.
           false, false, true, true, true, true, true,
           // Rows 7-13.
           false, false, false, true, true, true, false,
           // Rows 14-15.
           true, true}));

  // BETWEEN inputs: c0 is the probed value, c1 the lower bound and c2 the
  // upper bound. 11 rows; rows 8-10 each place a null in a different column.
  auto rangeRows = makeRowVector({
      // c0: value under test.
      makeNullableFlatVector<int128_t>(
          {toIp("2001:db8::ff00:42:8329"),
           toIp("1.1.1.1"),
           toIp("255.255.255.255"),
           toIp("::ffff:1.1.1.1"),
           toIp("1.1.1.1"),
           toIp("0.0.0.0"),
           toIp("::ffff"),
           toIp("0.0.0.0"),
           std::nullopt,
           toIp("0.0.0.0"),
           toIp("0.0.0.0")},
          IPADDRESS()),
      // c1: lower bound.
      makeNullableFlatVector<int128_t>(
          {toIp("::ffff"),
           toIp("1.1.1.1"),
           toIp("255.255.255.255"),
           toIp("::ffff:0.1.1.1"),
           toIp("0.1.1.1"),
           toIp("0.0.0.1"),
           toIp("::ffff:0.0.0.1"),
           toIp("2001:db8::0:0:0:1"),
           toIp("2001:db8::0:0:0:1"),
           std::nullopt,
           toIp("2001:db8::0:0:0:1")},
          IPADDRESS()),
      // c2: upper bound.
      makeNullableFlatVector<int128_t>(
          {toIp("2001:db8::ff00:42:8329"),
           toIp("1.1.1.1"),
           toIp("255.255.255.255"),
           toIp("2001:0db8:0000:0000:0000:ff00:0042:8329"),
           toIp("2001:0db8:0000:0000:0000:ff00:0042:8329"),
           toIp("0.0.0.2"),
           toIp("0.0.0.2"),
           toIp("2001:db8::1:0:0:1"),
           toIp("2001:db8::1:0:0:1"),
           toIp("2001:db8::1:0:0:1"),
           std::nullopt},
          IPADDRESS()),
  });

  checkExpression(
      "c0 between c1 and c2",
      rangeRows,
      makeNullableFlatVector<bool>(
          {// Rows 0-4: value lies within the inclusive bounds.
           true, true, true, true, true,
           // Rows 5-7: value falls below the lower bound.
           false, false, false,
           // Rows 8-10: null value, null lower bound, null upper bound.
           std::nullopt, std::nullopt, std::nullopt}));
}

TEST_F(ComparisonsTest, CustomComparisonWithGenerics) {
// Tests that functions that support signatures with generics handle custom
// comparison correctly.
Expand Down
40 changes: 40 additions & 0 deletions velox/functions/prestosql/tests/GreatestLeastTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,46 @@ TEST_F(GreatestLeastTest, leastDate) {
DATE());
}

// Verifies greatest() and least() over IPADDRESS arguments. Inputs are given
// as VARCHAR and cast to IPADDRESS inside the expression; the result is cast
// back to VARCHAR so it can be compared against a string.
TEST_F(GreatestLeastTest, greatestLeastIpAddress) {
  // Returns greatest(a, b, c) as text, or std::nullopt when the expression
  // evaluates to null.
  auto greatest = [&](const std::optional<std::string>& a,
                      const std::optional<std::string>& b,
                      const std::optional<std::string>& c) {
    return evaluateOnce<std::string>(
        "cast(greatest(cast(c0 as ipaddress), cast(c1 as ipaddress), cast(c2 as ipaddress)) as varchar)",
        a,
        b,
        c);
  };

  // Returns least(a, b, c) as text, or std::nullopt when the expression
  // evaluates to null.
  auto least = [&](const std::optional<std::string>& a,
                   const std::optional<std::string>& b,
                   const std::optional<std::string>& c) {
    return evaluateOnce<std::string>(
        "cast(least(cast(c0 as ipaddress), cast(c1 as ipaddress), cast(c2 as ipaddress)) as varchar)",
        a,
        b,
        c);
  };

  // The colon-separated literal is expected to be the greatest of the three
  // and to be rendered in its shortened text form.
  auto greatestValue = greatest(
      "1.1.1.1", "255.255.255.255", "2001:0db8:0000:0000:0000:ff00:0042:832");
  // Guard before dereferencing: an unexpected null result should surface as
  // an assertion failure rather than std::bad_optional_access.
  ASSERT_TRUE(greatestValue.has_value());
  EXPECT_EQ("2001:db8::ff00:42:832", greatestValue.value());

  auto leastValue = least(
      "1.1.1.1", "255.255.255.255", "2001:0db8:0000:0000:0000:ff00:0042:832");
  ASSERT_TRUE(leastValue.has_value());
  EXPECT_EQ("1.1.1.1", leastValue.value());

  // A null argument in any position is expected to make the result null.
  auto greatestValueWithNulls =
      greatest("1.1.1.1", "255.255.255.255", std::nullopt);
  EXPECT_FALSE(greatestValueWithNulls.has_value());

  auto leastValueWithNulls = least(
      std::nullopt,
      "255.255.255.255",
      "2001:0db8:0000:0000:0000:ff00:0042:832");
  EXPECT_FALSE(leastValueWithNulls.has_value());
}

TEST_F(GreatestLeastTest, stringBuffersMoved) {
runTest<StringView>(
"least(c0, c1)",
Expand Down
Loading

0 comments on commit d4eb5f7

Please sign in to comment.