Skip to content

Commit

Permalink
Add custom comparator to IPAddressType (facebookincubator#11347)
Browse files Browse the repository at this point in the history
Summary:

We need a comparator for IPAddressType.

- Pass `providesCustomComparison = true` to the type constructor to enable custom comparison.
- Compare by converting the int128_t/BIGINT to a byte array and reversing the byte order; the resulting bytes are then compared with memcmp.
- Support `between` for IPAddressType.
- Register the type to saber, since comparison now takes IPAddress.

Differential Revision: D64880148
  • Loading branch information
yuandagits authored and facebook-github-bot committed Oct 25, 2024
1 parent c7fe8e7 commit 00ff7b8
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 17 deletions.
5 changes: 5 additions & 0 deletions velox/functions/prestosql/GreatestLeast.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ struct ExtremeValueFunction {
return wrapper;
}

// Overload for custom-comparison views over int128_t: returns the value
// obtained by dereferencing the view.
int128_t extractValue(
const exec::CustomTypeWithCustomComparisonView<int128_t>& wrapper) const {
const int128_t unwrapped = *wrapper;
return unwrapped;
}

int64_t extractValue(
const exec::CustomTypeWithCustomComparisonView<int64_t>& wrapper) const {
return *wrapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/Type.h"

Expand All @@ -30,13 +31,15 @@ void registerNonSimdizableScalar(const std::vector<std::string>& aliases) {
registerFunction<T, TReturn, Timestamp, Timestamp>(aliases);
registerFunction<T, TReturn, TimestampWithTimezone, TimestampWithTimezone>(
aliases);
registerFunction<T, TReturn, IPAddress, IPAddress>(aliases);
}
} // namespace

void registerComparisonFunctions(const std::string& prefix) {
// Comparison functions also need TimestampWithTimezoneType,
// independent of DateTimeFunctions
registerTimestampWithTimeZoneType();
registerIPAddressType();

registerNonSimdizableScalar<EqFunction, bool>({prefix + "eq"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_eq, prefix + "eq");
Expand Down Expand Up @@ -116,6 +119,8 @@ void registerComparisonFunctions(const std::string& prefix) {
TimestampWithTimezone,
TimestampWithTimezone,
TimestampWithTimezone>({prefix + "between"});
registerFunction<BetweenFunction, bool, IPAddress, IPAddress, IPAddress>(
{prefix + "between"});
}

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/functions/prestosql/GreatestLeast.h"
#include "velox/functions/prestosql/InPredicate.h"
#include "velox/functions/prestosql/Reduce.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -57,6 +58,7 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) {
registerGreatestLeastFunction<Date>(prefix);
registerGreatestLeastFunction<Timestamp>(prefix);
registerGreatestLeastFunction<TimestampWithTimezone>(prefix);
registerGreatestLeastFunction<IPAddress>(prefix);
}
} // namespace

Expand Down
230 changes: 230 additions & 0 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/Udf.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/IPAddressType.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/type/tests/utils/CustomTypesForTesting.h"
#include "velox/type/tz/TimeZoneMap.h"
Expand Down Expand Up @@ -1182,6 +1183,235 @@ TEST_F(ComparisonsTest, TimestampWithTimezone) {
false}));
}

// Exercises the comparison expressions (=, <>, <, >, <=, >=,
// "is distinct from" and "between") on vectors typed as IPADDRESS().
// Inputs are built from textual addresses; each expected vector is positional,
// i.e. entry i is the expected result for row i of the input vectors.
TEST_F(ComparisonsTest, IpAddressType) {
// Converts a textual IP address into the int128_t value stored in the vectors.
// NOTE(review): `.value()` is called without checking for an error first —
// presumably acceptable because every literal below is a valid address;
// confirm.
auto makeIpAdressFromString = [](const std::string& ipAddr) -> int128_t {
auto ret = tryGetIPv6asInt128FromString(ipAddr);
return ret.value();
};

// Evaluates `expr` over `inputs` as a boolean vector and asserts that it
// equals `expectedResult`.
auto runAndCompare = [&](const std::string& expr,
RowVectorPtr inputs,
VectorPtr expectedResult) {
auto actual = evaluate<SimpleVector<bool>>(expr, inputs);
test::assertEqualVectors(expectedResult, actual);
};

// Left-hand operands (column c0). Row i is compared against row i of `rhs`.
// The rows mix dotted IPv4 literals, IPv6 literals (full and compressed) and
// IPv6 literals with an embedded dotted-quad suffix.
auto lhs = makeFlatVector<int128_t>(
std::vector<int128_t>{
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("1.2.3.4"),
makeIpAdressFromString("1.1.1.2"),
makeIpAdressFromString("1.1.2.1"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("::1"),
makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"),
makeIpAdressFromString("::ffff:1.2.3.4"),
makeIpAdressFromString("::ffff:0.1.1.1"),
makeIpAdressFromString("::FFFF:FFFF:FFFF"),
makeIpAdressFromString("::0001:255.255.255.255"),
makeIpAdressFromString("::ffff:ffff:ffff"),
},
IPADDRESS());

// Right-hand operands (column c1), aligned row-by-row with `lhs`.
auto rhs = makeFlatVector<int128_t>(
std::vector<int128_t>{
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("1.1.1.2"),
makeIpAdressFromString("1.1.2.1"),
makeIpAdressFromString("255.1.1.1"),
makeIpAdressFromString("::1"),
makeIpAdressFromString("2001:db8::ff00:42:8329"),
makeIpAdressFromString("1.2.3.4"),
makeIpAdressFromString("::ffff:1.1.1.0"),
makeIpAdressFromString("::0001:255.255.255.255"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("255.255.255.255"),
},
IPADDRESS());

auto input = makeRowVector({lhs, rhs});

// Equality. Besides the textually identical rows, the expectations treat the
// full and compressed spellings of the same IPv6 address (row 8) as equal, and
// a dotted IPv4 literal as equal to its "::ffff:"-prefixed form (rows 9, 13).
runAndCompare(
"c0 = c1",
input,
makeFlatVector<bool>(
{true,
true,
false,
false,
false,
false,
false,
true,
true,
true,
false,
false,
false,
true}));

// Inequality: the element-wise negation of the equality expectations above.
runAndCompare(
"c0 <> c1",
input,
makeFlatVector<bool>(
{false,
false,
true,
true,
true,
true,
true,
false,
false,
false,
true,
true,
true,
false}));

// Strict less-than.
runAndCompare(
"c0 < c1",
input,
makeFlatVector<bool>(
{false,
false,
false,
false,
false,
true,
true,
false,
false,
false,
true,
false,
true,
false}));

// Strict greater-than.
runAndCompare(
"c0 > c1",
input,
makeFlatVector<bool>(
{false,
false,
true,
true,
true,
false,
false,
false,
false,
false,
false,
true,
false,
false}));

// Less-than-or-equal: true wherever the "<" or "=" expectation is true.
runAndCompare(
"c0 <= c1",
input,
makeFlatVector<bool>(
{true,
true,
false,
false,
false,
true,
true,
true,
true,
true,
true,
false,
true,
true}));

// Greater-than-or-equal: true wherever the ">" or "=" expectation is true.
runAndCompare(
"c0 >= c1",
input,
makeFlatVector<bool>(
{true,
true,
true,
true,
true,
false,
false,
true,
true,
true,
false,
true,
false,
true}));

// "is distinct from": with these inputs the expectations are identical to
// those of "<>".
runAndCompare(
"c0 is distinct from c1",
input,
makeFlatVector<bool>(
{false,
false,
true,
true,
true,
true,
true,
false,
false,
false,
true,
true,
true,
false}));

// Inputs for "between": c0 is the value under test, c1 the lower bound and
// c2 the upper bound, aligned row-by-row.
auto betweenInput = makeRowVector({
makeFlatVector<int128_t>(
std::vector<int128_t>{
makeIpAdressFromString("2001:db8::ff00:42:8329"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("::ffff:1.1.1.1"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("0.0.0.0"),
makeIpAdressFromString("::ffff"),
makeIpAdressFromString("0.0.0.0")},
IPADDRESS()),
makeFlatVector<int128_t>(
std::vector<int128_t>{
makeIpAdressFromString("::ffff"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("::ffff:0.1.1.1"),
makeIpAdressFromString("0.1.1.1"),
makeIpAdressFromString("0.0.0.1"),
makeIpAdressFromString("::ffff:0.0.0.1"),
makeIpAdressFromString("2001:db8::0:0:0:1")},
IPADDRESS()),
makeFlatVector<int128_t>(
std::vector<int128_t>{
makeIpAdressFromString("2001:db8::ff00:42:8329"),
makeIpAdressFromString("1.1.1.1"),
makeIpAdressFromString("255.255.255.255"),
makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"),
makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"),
makeIpAdressFromString("0.0.0.2"),
makeIpAdressFromString("0.0.0.2"),
makeIpAdressFromString("2001:db8::1:0:0:1")},
IPADDRESS()),
});

// The first five rows are expected to lie within [c1, c2] (rows 1 and 2 have
// value == lower == upper, so the expectations treat both bounds as
// inclusive); the last three rows are expected to fall outside the range.
runAndCompare(
"c0 between c1 and c2",
betweenInput,
makeFlatVector<bool>(
{true, true, true, true, true, false, false, false}));
}

TEST_F(ComparisonsTest, CustomComparisonWithGenerics) {
// Tests that functions that support signatures with generics handle custom
// comparison correctly.
Expand Down
18 changes: 3 additions & 15 deletions velox/functions/prestosql/types/IPAddressType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,8 @@
*/

#include "velox/functions/prestosql/types/IPAddressType.h"
#include <folly/IPAddress.h>
#include "velox/expression/CastExpr.h"

static constexpr int kIPV4AddressBytes = 4;
static constexpr int kIPV4ToV6FFIndex = 10;
static constexpr int kIPV4ToV6Index = 12;
static constexpr int kIPAddressBytes = 16;

namespace facebook::velox {

namespace {
Expand Down Expand Up @@ -123,9 +117,9 @@ class IPAddressCastOperator : public exec::CastOperator {

context.applyToSelectedNoThrow(rows, [&](auto row) {
const auto ipAddressString = ipAddressStrings->valueAt(row);
auto maybeIpAsInt128 = tryGetIPv6asInt128FromString(ipAddressString);

auto maybeIp = folly::IPAddress::tryFromString(ipAddressString);
if (maybeIp.hasError()) {
if (maybeIpAsInt128.hasError()) {
if (threadSkipErrorDetails()) {
context.setStatus(row, Status::UserError());
} else {
Expand All @@ -135,13 +129,7 @@ class IPAddressCastOperator : public exec::CastOperator {
}
return;
}
folly::IPAddress addr = maybeIp.value();
auto addrBytes = folly::IPAddress::createIPv6(addr).toByteArray();

std::reverse(addrBytes.begin(), addrBytes.end());
memcpy(&intAddr, &addrBytes, kIPAddressBytes);

flatResult->set(row, intAddr);
flatResult->set(row, maybeIpAsInt128.value());
});
}

Expand Down
Loading

0 comments on commit 00ff7b8

Please sign in to comment.