From b3e0e6851ef511b52fe8e196ac5290e76e2776e8 Mon Sep 17 00:00:00 2001 From: Yenda Li Date: Thu, 24 Oct 2024 14:31:34 -0700 Subject: [PATCH] Add custom comparator to IPAddressType Summary: We need a comparator for IPAddressType. - Provide providesCustomComparison = true to constructor to support custom comparison - Compare by converting the int128_t/BIGINT to a byteArray, and reversing the byte order. We then use memcmp - support between for IPAddressType Differential Revision: D64880148 --- velox/functions/prestosql/GreatestLeast.h | 5 + .../ComparisonFunctionsRegistration.cpp | 5 + .../GeneralFunctionsRegistration.cpp | 2 + .../prestosql/tests/ComparisonsTest.cpp | 230 ++++++++++++++++++ .../prestosql/types/IPAddressType.cpp | 17 +- .../functions/prestosql/types/IPAddressType.h | 48 +++- .../types/tests/IPAddressTypeTest.cpp | 84 +++++++ 7 files changed, 375 insertions(+), 16 deletions(-) diff --git a/velox/functions/prestosql/GreatestLeast.h b/velox/functions/prestosql/GreatestLeast.h index a5a0a23e00d56..2406c34a18bfc 100644 --- a/velox/functions/prestosql/GreatestLeast.h +++ b/velox/functions/prestosql/GreatestLeast.h @@ -66,6 +66,11 @@ struct ExtremeValueFunction { return wrapper; } + int128_t extractValue( + const exec::CustomTypeWithCustomComparisonView& wrapper) const { + return *wrapper; + } + int64_t extractValue( const exec::CustomTypeWithCustomComparisonView& wrapper) const { return *wrapper; diff --git a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp index b1f34da9c6296..2870a3ff787f0 100644 --- a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp @@ -16,6 +16,7 @@ #include "velox/functions/Registerer.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/Comparisons.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/type/Type.h" @@ -30,6 +31,7 @@ void registerNonSimdizableScalar(const std::vector& aliases) { registerFunction(aliases); registerFunction( aliases); + registerFunction(aliases); } } // namespace @@ -37,6 +39,7 @@ void registerComparisonFunctions(const std::string& prefix) { // Comparison functions also need TimestampWithTimezoneType, // independent of DateTimeFunctions registerTimestampWithTimeZoneType(); + registerIPAddressType(); registerNonSimdizableScalar({prefix + "eq"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_eq, prefix + "eq"); @@ -116,6 +119,8 @@ void registerComparisonFunctions(const std::string& prefix) { TimestampWithTimezone, TimestampWithTimezone, TimestampWithTimezone>({prefix + "between"}); + registerFunction( + {prefix + "between"}); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index dda91e413da91..a33d0a62a6bf7 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -21,6 +21,7 @@ #include "velox/functions/prestosql/GreatestLeast.h" #include "velox/functions/prestosql/InPredicate.h" #include "velox/functions/prestosql/Reduce.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::functions { @@ -57,6 +58,7 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) { registerGreatestLeastFunction(prefix); registerGreatestLeastFunction(prefix); registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); } } // namespace diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index 14be9d7a8cd60..1da26378fa4cd 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -19,6 +19,7 @@ #include "velox/functions/Udf.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/type/tz/TimeZoneMap.h" @@ -1182,6 +1183,235 @@ TEST_F(ComparisonsTest, TimestampWithTimezone) { false})); } +TEST_F(ComparisonsTest, IpAddressType) { + auto makeIpAdressFromString = [](const std::string& ipAddr) -> int128_t { + auto ret = tryGetIPv6asInt128FromString(ipAddr); + return ret.value(); + }; + + auto runAndCompare = [&](const std::string& expr, + RowVectorPtr inputs, + VectorPtr expectedResult) { + auto actual = evaluate>(expr, inputs); + test::assertEqualVectors(expectedResult, actual); + }; + + auto lhs = makeFlatVector( + std::vector{ + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("1.2.3.4"), + makeIpAdressFromString("1.1.1.2"), + makeIpAdressFromString("1.1.2.1"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("::1"), + makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"), + makeIpAdressFromString("::ffff:1.2.3.4"), + makeIpAdressFromString("::ffff:0.1.1.1"), + makeIpAdressFromString("::FFFF:FFFF:FFFF"), + makeIpAdressFromString("::0001:255.255.255.255"), + makeIpAdressFromString("::ffff:ffff:ffff"), + }, + IPADDRESS()); + + auto rhs = makeFlatVector( + std::vector{ + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("1.1.1.2"), + makeIpAdressFromString("1.1.2.1"), + makeIpAdressFromString("255.1.1.1"), + makeIpAdressFromString("::1"), + makeIpAdressFromString("2001:db8::ff00:42:8329"), + makeIpAdressFromString("1.2.3.4"), + makeIpAdressFromString("::ffff:1.1.1.0"), + makeIpAdressFromString("::0001:255.255.255.255"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("255.255.255.255"), + }, + IPADDRESS()); + + auto input = makeRowVector({lhs, rhs}); + + runAndCompare( + "c0 = c1", + input, + makeFlatVector( + {true, + true, + false, + false, + false, + false, + false, + true, + true, + true, + false, + false, + false, + true})); + + runAndCompare( + "c0 <> c1", + input, + makeFlatVector( + {false, + false, + true, + true, + true, + true, + true, + false, + false, + false, + true, + true, + true, + false})); + + runAndCompare( + "c0 < c1", + input, + makeFlatVector( + {false, + false, + false, + false, + false, + true, + true, + false, + false, + false, + true, + false, + true, + false})); + + runAndCompare( + "c0 > c1", + input, + makeFlatVector( + {false, + false, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + false, + false})); + + runAndCompare( + "c0 <= c1", + input, + makeFlatVector( + {true, + true, + false, + false, + false, + true, + true, + true, + true, + true, + true, + false, + true, + true})); + + runAndCompare( + "c0 >= c1", + input, + makeFlatVector( + {true, + true, + true, + true, + true, + false, + false, + true, + true, + true, + false, + true, + false, + true})); + + runAndCompare( + "c0 is distinct from c1", + input, + makeFlatVector( + {false, + false, + true, + true, + true, + true, + true, + false, + false, + false, + true, + true, + true, + false})); + + auto betweenInput = makeRowVector({ + makeFlatVector( + std::vector{ + makeIpAdressFromString("2001:db8::ff00:42:8329"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("::ffff:1.1.1.1"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("0.0.0.0"), + makeIpAdressFromString("::ffff"), + makeIpAdressFromString("0.0.0.0")}, + IPADDRESS()), + makeFlatVector( + std::vector{ + makeIpAdressFromString("::ffff"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("::ffff:0.1.1.1"), + makeIpAdressFromString("0.1.1.1"), + makeIpAdressFromString("0.0.0.1"), + makeIpAdressFromString("::ffff:0.0.0.1"), + makeIpAdressFromString("2001:db8::0:0:0:1")}, + IPADDRESS()), + makeFlatVector( + std::vector{ + makeIpAdressFromString("2001:db8::ff00:42:8329"), + makeIpAdressFromString("1.1.1.1"), + makeIpAdressFromString("255.255.255.255"), + makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"), + makeIpAdressFromString("2001:0db8:0000:0000:0000:ff00:0042:8329"), + makeIpAdressFromString("0.0.0.2"), + makeIpAdressFromString("0.0.0.2"), + makeIpAdressFromString("2001:db8::1:0:0:1")}, + IPADDRESS()), + }); + + runAndCompare( + "c0 between c1 and c2", + betweenInput, + makeFlatVector( + {true, true, true, true, true, false, false, false})); +} + TEST_F(ComparisonsTest, CustomComparisonWithGenerics) { // Tests that functions that support signatures with generics handle custom // comparison correctly. diff --git a/velox/functions/prestosql/types/IPAddressType.cpp b/velox/functions/prestosql/types/IPAddressType.cpp index 691ca0a28ce2b..c59ee855c6d74 100644 --- a/velox/functions/prestosql/types/IPAddressType.cpp +++ b/velox/functions/prestosql/types/IPAddressType.cpp @@ -18,11 +18,6 @@ #include #include "velox/expression/CastExpr.h" -static constexpr int kIPV4AddressBytes = 4; -static constexpr int kIPV4ToV6FFIndex = 10; -static constexpr int kIPV4ToV6Index = 12; -static constexpr int kIPAddressBytes = 16; - namespace facebook::velox { namespace { @@ -123,9 +118,9 @@ class IPAddressCastOperator : public exec::CastOperator { context.applyToSelectedNoThrow(rows, [&](auto row) { const auto ipAddressString = ipAddressStrings->valueAt(row); + auto maybeIpAsInt128 = tryGetIPv6asInt128FromString(ipAddressString); - auto maybeIp = folly::IPAddress::tryFromString(ipAddressString); - if (maybeIp.hasError()) { + if (maybeIpAsInt128.hasError()) { if (threadSkipErrorDetails()) { context.setStatus(row, Status::UserError()); } else { @@ -135,13 +130,7 @@ class IPAddressCastOperator : public exec::CastOperator { } return; } - folly::IPAddress addr = maybeIp.value(); - auto addrBytes = folly::IPAddress::createIPv6(addr).toByteArray(); - - std::reverse(addrBytes.begin(), addrBytes.end()); - memcpy(&intAddr, &addrBytes, kIPAddressBytes); - - flatResult->set(row, intAddr); + flatResult->set(row, maybeIpAsInt128.value()); }); } diff --git a/velox/functions/prestosql/types/IPAddressType.h b/velox/functions/prestosql/types/IPAddressType.h index e1e2d9fc1bf28..48a29ece83833 100644 --- a/velox/functions/prestosql/types/IPAddressType.h +++ b/velox/functions/prestosql/types/IPAddressType.h @@ -15,13 +15,35 @@ */ #pragma once +#include + #include "velox/type/SimpleFunctionApi.h" #include "velox/type/Type.h" namespace facebook::velox { +constexpr int kIPV4AddressBytes = 4; +constexpr int kIPV4ToV6FFIndex = 10; +constexpr int kIPV4ToV6Index = 12; +constexpr int kIPAddressBytes = 16; + +inline folly::Expected +tryGetIPv6asInt128FromString(const std::string& ipAddressStr) { + auto maybeIp = folly::IPAddress::tryFromString(ipAddressStr); + if (maybeIp.hasError()) { + return folly::makeUnexpected(maybeIp.error()); + } + + int128_t intAddr; + folly::IPAddress addr = maybeIp.value(); + auto addrBytes = folly::IPAddress::createIPv6(addr).toByteArray(); + std::reverse(addrBytes.begin(), addrBytes.end()); + memcpy(&intAddr, &addrBytes, kIPAddressBytes); + return intAddr; +} + class IPAddressType : public HugeintType { - IPAddressType() = default; + IPAddressType() : HugeintType(/*providesCustomComparison*/ true) {} public: static const std::shared_ptr& get() { @@ -31,6 +53,17 @@ class IPAddressType : public HugeintType { return instance; } + int32_t compare(const int128_t& left, const int128_t& right) const override { + const auto leftAddrBytes = toIPv6ByteArray(left); + const auto rightAddrBytes = toIPv6ByteArray(right); + return memcmp( + leftAddrBytes.begin(), rightAddrBytes.begin(), kIPAddressBytes); + } + + uint64_t hash(const int128_t& value) const override { + return folly::hasher()(value); + } + bool equivalent(const Type& other) const override { // Pointer comparison works since this type is a singleton. return this == &other; @@ -50,6 +83,17 @@ class IPAddressType : public HugeintType { obj["type"] = name(); return obj; } + + private: + static std::array toIPv6ByteArray( + const int128_t& ipAddr) { + std::array bytes{{0}}; + memcpy(bytes.data(), &ipAddr, sizeof(ipAddr)); + // Reverse because the velox is always on little endian system + // and the byte array needs to be big endian (network byte order) + std::reverse(bytes.begin(), bytes.end()); + return bytes; + } }; FOLLY_ALWAYS_INLINE bool isIPAddressType(const TypePtr& type) { @@ -67,7 +111,7 @@ struct IPAddressT { static constexpr const char* typeName = "ipaddress"; }; -using IPAddress = CustomType; +using IPAddress = CustomType; void registerIPAddressType(); diff --git a/velox/functions/prestosql/types/tests/IPAddressTypeTest.cpp b/velox/functions/prestosql/types/tests/IPAddressTypeTest.cpp index f024c785d94b3..7c4da640c9fea 100644 --- a/velox/functions/prestosql/types/tests/IPAddressTypeTest.cpp +++ b/velox/functions/prestosql/types/tests/IPAddressTypeTest.cpp @@ -23,6 +23,11 @@ class IPAddressTypeTest : public testing::Test, public TypeTestBase { IPAddressTypeTest() { registerIPAddressType(); } + + int128_t getIPv6asInt128FromStringUnchecked(const std::string& ipAddr) { + auto ret = tryGetIPv6asInt128FromString(ipAddr); + return ret.value(); + } }; TEST_F(IPAddressTypeTest, basic) { @@ -38,4 +43,83 @@ TEST_F(IPAddressTypeTest, basic) { TEST_F(IPAddressTypeTest, serde) { testTypeSerde(IPADDRESS()); } + +TEST_F(IPAddressTypeTest, compare) { + auto ipAddr = IPADDRESS(); + // Baisc IPv4 test + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.1.1.1"), + getIPv6asInt128FromStringUnchecked("1.1.1.1"))); + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("255.255.255.255"), + getIPv6asInt128FromStringUnchecked("255.255.255.255"))); + ASSERT_GT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.2.3.4"), + getIPv6asInt128FromStringUnchecked("1.1.1.1")), + 0); + ASSERT_GT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.1.1.2"), + getIPv6asInt128FromStringUnchecked("1.1.1.1")), + 0); + ASSERT_GT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.1.2.1"), + getIPv6asInt128FromStringUnchecked("1.1.1.2")), + 0); + ASSERT_LT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.1.1.1"), + getIPv6asInt128FromStringUnchecked("1.1.2.1")), + 0); + ASSERT_LT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("1.1.1.1"), + getIPv6asInt128FromStringUnchecked("255.1.1.1")), + 0); + + // Basic IPv6 test + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::1"), + getIPv6asInt128FromStringUnchecked("::1"))); + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked( + "2001:0db8:0000:0000:0000:ff00:0042:8329"), + getIPv6asInt128FromStringUnchecked("2001:db8::ff00:42:8329"))); + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::ffff:1.2.3.4"), + getIPv6asInt128FromStringUnchecked("1.2.3.4"))); + ASSERT_LT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::ffff:0.1.1.1"), + getIPv6asInt128FromStringUnchecked("::ffff:1.1.1.0")), + 0); + + ASSERT_GT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::FFFF:FFFF:FFFF"), + getIPv6asInt128FromStringUnchecked("::0001:255.255.255.255")), + 0); + ASSERT_LT( + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::0001:255.255.255.255"), + getIPv6asInt128FromStringUnchecked("255.255.255.255")), + 0); + ASSERT_EQ( + 0, + ipAddr->compare( + getIPv6asInt128FromStringUnchecked("::ffff:ffff:ffff"), + getIPv6asInt128FromStringUnchecked("255.255.255.255"))); +} } // namespace facebook::velox::test