diff --git a/velox/functions/prestosql/tests/IPAddressCastTest.cpp b/velox/functions/prestosql/tests/IPAddressCastTest.cpp index a13ec9114e94a..d211f031a3cfe 100644 --- a/velox/functions/prestosql/tests/IPAddressCastTest.cpp +++ b/velox/functions/prestosql/tests/IPAddressCastTest.cpp @@ -43,6 +43,13 @@ class IPAddressCastTest : public functions::test::FunctionBaseTest { input); return result; } + + auto castToIPPrefixAndBackToIpVarchar( + const std::optional& input) { + return evaluateOnce( + "cast(cast(cast(cast(cast(cast(c0 as ipaddress) as ipprefix) as varchar) as ipprefix) as ipaddress) as varchar)", + input); + } }; int128_t stringToInt128(const std::string& value) { @@ -53,6 +60,24 @@ int128_t stringToInt128(const std::string& value) { return res; } +TEST_F(IPAddressCastTest, castToIPPrefix) { + EXPECT_EQ(castToIPPrefixAndBackToIpVarchar("1.2.3.4"), "1.2.3.4"); + EXPECT_EQ(castToIPPrefixAndBackToIpVarchar("::ffff:1.2.3.4"), "1.2.3.4"); + EXPECT_EQ(castToIPPrefixAndBackToIpVarchar("::ffff:102:304"), "1.2.3.4"); + EXPECT_EQ(castToIPPrefixAndBackToIpVarchar("192.168.0.0"), "192.168.0.0"); + EXPECT_EQ( + castToIPPrefixAndBackToIpVarchar( + "2001:0db8:0000:0000:0000:ff00:0042:8329"), + "2001:db8::ff00:42:8329"); + EXPECT_EQ( + castToIPPrefixAndBackToIpVarchar("2001:db8:0:0:1:0:0:1"), + "2001:db8::1:0:0:1"); + EXPECT_EQ(castToIPPrefixAndBackToIpVarchar("::1"), "::1"); + EXPECT_EQ( + castToIPPrefixAndBackToIpVarchar("2001:db8::ff00:42:8329"), + "2001:db8::ff00:42:8329"); +} + TEST_F(IPAddressCastTest, castToVarchar) { EXPECT_EQ(castToVarchar("::ffff:1.2.3.4"), "1.2.3.4"); EXPECT_EQ(castToVarchar("0:0:0:0:0:0:13.1.68.3"), "::13.1.68.3"); diff --git a/velox/functions/prestosql/tests/IPPrefixCastTest.cpp b/velox/functions/prestosql/tests/IPPrefixCastTest.cpp index c8c77f40ec528..745e2cfa72d07 100644 --- a/velox/functions/prestosql/tests/IPPrefixCastTest.cpp +++ b/velox/functions/prestosql/tests/IPPrefixCastTest.cpp @@ -26,8 +26,47 @@ class IPPrefixTypeTest : public functions::test::FunctionBaseTest { "cast(cast(c0 as ipprefix) as varchar)", input); return result; } + + std::optional castToIpAddress( + const std::optional& input) { + return evaluateOnce( + "cast(cast(cast(c0 as ipprefix) as ipaddress) as varchar)", input); + } + + std::optional castFromIPAddress( + const std::optional& input) { + return evaluateOnce( + "cast(cast(cast(c0 as ipaddress) as ipprefix) as varchar)", input); + } }; +TEST_F(IPPrefixTypeTest, castFromIpAddress) { + EXPECT_EQ(castFromIPAddress(std::nullopt), std::nullopt); + EXPECT_EQ(castFromIPAddress("1.2.3.4"), "1.2.3.4/32"); + EXPECT_EQ(castFromIPAddress("::ffff:1.2.3.4"), "1.2.3.4/32"); + EXPECT_EQ(castFromIPAddress("::ffff:102:304"), "1.2.3.4/32"); + EXPECT_EQ(castFromIPAddress("192.168.0.0"), "192.168.0.0/32"); + EXPECT_EQ( + castFromIPAddress("2001:0db8:0000:0000:0000:ff00:0042:8329"), + "2001:db8::ff00:42:8329/128"); + EXPECT_EQ(castFromIPAddress("2001:db8:0:0:1:0:0:1"), "2001:db8::1:0:0:1/128"); + EXPECT_EQ(castFromIPAddress("::1"), "::1/128"); + EXPECT_EQ( + castFromIPAddress("2001:db8::ff00:42:8329"), + "2001:db8::ff00:42:8329/128"); + EXPECT_EQ(castFromIPAddress("2001:db8::"), "2001:db8::/128"); +} + +TEST_F(IPPrefixTypeTest, castToIpAddress) { + EXPECT_EQ(castToIpAddress(std::nullopt), std::nullopt); + EXPECT_EQ(castToIpAddress("1.2.3.4/32"), "1.2.3.4"); + EXPECT_EQ(castToIpAddress("1.2.3.4/24"), "1.2.3.0"); + EXPECT_EQ(castToIpAddress("::1/128"), "::1"); + EXPECT_EQ( + castToIpAddress("2001:db8::ff00:42:8329/128"), "2001:db8::ff00:42:8329"); + EXPECT_EQ(castToIpAddress("2001:db8::ff00:42:8329/64"), "2001:db8::"); +} + TEST_F(IPPrefixTypeTest, castToVarchar) { EXPECT_EQ(castToVarchar("::ffff:1.2.3.4/24"), "1.2.3.0/24"); EXPECT_EQ(castToVarchar("192.168.0.0/24"), "192.168.0.0/24"); diff --git a/velox/functions/prestosql/types/IPAddressType.cpp b/velox/functions/prestosql/types/IPAddressType.cpp index d9569c01f7a97..6f5a3a634c2e2 100644 --- a/velox/functions/prestosql/types/IPAddressType.cpp +++ b/velox/functions/prestosql/types/IPAddressType.cpp @@ -16,6 +16,8 @@ #include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/expression/CastExpr.h" +#include "velox/expression/VectorWriters.h" +#include "velox/functions/prestosql/types/IPPrefixType.h" namespace facebook::velox { @@ -28,6 +30,11 @@ class IPAddressCastOperator : public exec::CastOperator { case TypeKind::VARBINARY: case TypeKind::VARCHAR: return true; + case TypeKind::ROW: + if (isIPPrefixType(other)) { + return true; + } + [[fallthrough]]; default: return false; } @@ -38,6 +45,11 @@ class IPAddressCastOperator : public exec::CastOperator { case TypeKind::VARBINARY: case TypeKind::VARCHAR: return true; + case TypeKind::ROW: + if (isIPPrefixType(other)) { + return true; + } + [[fallthrough]]; default: return false; } @@ -55,6 +67,9 @@ class IPAddressCastOperator : public exec::CastOperator { castFromString(input, context, rows, *result); } else if (input.typeKind() == TypeKind::VARBINARY) { castFromVarbinary(input, context, rows, *result); + } else if ( + input.typeKind() == TypeKind::ROW && isIPPrefixType(input.type())) { + castFromIPPrefix(input, context, rows, *result); } else { VELOX_UNSUPPORTED( "Cast from {} to IPAddress not supported", resultType->toString()); @@ -73,6 +88,9 @@ class IPAddressCastOperator : public exec::CastOperator { castToString(input, context, rows, *result); } else if (resultType->kind() == TypeKind::VARBINARY) { castToVarbinary(input, context, rows, *result); + } else if ( + resultType->kind() == TypeKind::ROW && isIPPrefixType(resultType)) { + castToIPPrefix(input, context, rows, *result); } else { VELOX_UNSUPPORTED( "Cast from IPAddress to {} not supported", resultType->toString()); @@ -155,6 +173,47 @@ class IPAddressCastOperator : public exec::CastOperator { }); } + static void castToIPPrefix( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* rowVectorResult = result.as(); + const auto* ipaddresses = input.as>(); + + context.applyToSelectedNoThrow(rows, [&](auto row) { + const auto ipAddrVal = ipaddresses->valueAt(row); + const auto tryPrefixLength = + ipaddress::tryIpPrefixLengthFromIPAddressType(ipAddrVal); + if (tryPrefixLength.hasError()) { + context.setStatus(row, std::move(tryPrefixLength).error()); + return; + } + + auto writer = exec::VectorWriter>(); + writer.init(*rowVectorResult); + writer.setOffset(row); + auto& rowWriter = writer.current(); + rowWriter.get_writer_at<0>() = ipAddrVal; + rowWriter.get_writer_at<1>() = tryPrefixLength.value(); + writer.commit(); + }); + } + + static void castFromIPPrefix( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* flatResult = result.as>(); + const auto* ipprefix = input.as(); + const auto* ipaddr = + ipprefix->childAt(ipaddress::kIpRowIndex)->as>(); + + context.applyToSelectedNoThrow( + rows, [&](auto row) { flatResult->set(row, ipaddr->valueAt(row)); }); + } + static void castFromVarbinary( const BaseVector& input, exec::EvalCtx& context, diff --git a/velox/functions/prestosql/types/IPPrefixType.cpp b/velox/functions/prestosql/types/IPPrefixType.cpp index b835876df2322..86486607af58f 100644 --- a/velox/functions/prestosql/types/IPPrefixType.cpp +++ b/velox/functions/prestosql/types/IPPrefixType.cpp @@ -29,6 +29,8 @@ class IPPrefixCastOperator : public exec::CastOperator { switch (other->kind()) { case TypeKind::VARCHAR: return true; + case TypeKind::HUGEINT: + return isIPAddressType(other); default: return false; } @@ -38,6 +40,8 @@ class IPPrefixCastOperator : public exec::CastOperator { switch (other->kind()) { case TypeKind::VARCHAR: return true; + case TypeKind::HUGEINT: + return isIPAddressType(other); default: return false; } @@ -53,6 +57,12 @@ class IPPrefixCastOperator : public exec::CastOperator { switch (input.typeKind()) { case TypeKind::VARCHAR: return castFromString(input, context, rows, *result); + case TypeKind::HUGEINT: { + if (isIPAddressType(input.type())) { + return castFromIpAddress(input, context, rows, *result); + } + [[fallthrough]]; + } default: VELOX_NYI( "Cast from {} to IPPrefix not yet supported", @@ -70,6 +80,12 @@ class IPPrefixCastOperator : public exec::CastOperator { switch (resultType->kind()) { case TypeKind::VARCHAR: return castToString(input, context, rows, *result); + case TypeKind::HUGEINT: { + if (isIPAddressType(input.type())) { + return castToIpAddress(input, context, rows, *result); + } + [[fallthrough]]; + } default: VELOX_NYI( "Cast from IPPrefix to {} not yet supported", @@ -78,6 +94,19 @@ class IPPrefixCastOperator : public exec::CastOperator { } private: + static void castToIpAddress( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* flatResult = result.as>(); + auto rowVector = input.as(); + const auto* ipaddr = rowVector->childAt(ipaddress::kIpRowIndex) + ->as>(); + context.applyToSelectedNoThrow( + rows, [&](auto row) { flatResult->set(row, ipaddr->valueAt(row)); }); + } + static void castToString( const BaseVector& input, exec::EvalCtx& context, @@ -85,7 +114,6 @@ class IPPrefixCastOperator : public exec::CastOperator { BaseVector& result) { auto* flatResult = result.as>(); auto rowVector = input.as(); - auto rowType = rowVector->type(); const auto* ipaddr = rowVector->childAt(ipaddress::kIpRowIndex) ->as>(); const auto* prefix = rowVector->childAt(ipaddress::kIpPrefixRowIndex) @@ -120,6 +148,33 @@ class IPPrefixCastOperator : public exec::CastOperator { }); } + static void castFromIpAddress( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* rowVectorResult = result.as(); + const auto* ipAddrVector = input.as>(); + + context.applyToSelectedNoThrow(rows, [&](auto row) { + auto intIpAddr = ipAddrVector->valueAt(row); + const auto tryPrefixLength = + ipaddress::tryIpPrefixLengthFromIPAddressType(intIpAddr); + if (tryPrefixLength.hasError()) { + context.setStatus(row, std::move(tryPrefixLength).error()); + return; + } + + auto writer = exec::VectorWriter>(); + writer.init(*rowVectorResult); + writer.setOffset(row); + auto& rowWriter = writer.current(); + rowWriter.get_writer_at<0>() = intIpAddr; + rowWriter.get_writer_at<1>() = tryPrefixLength.value(); + writer.commit(); + }); + } + static void castFromString( const BaseVector& input, exec::EvalCtx& context, diff --git a/velox/functions/prestosql/types/IPPrefixType.h b/velox/functions/prestosql/types/IPPrefixType.h index bb387ef6f496d..84de117548dc6 100644 --- a/velox/functions/prestosql/types/IPPrefixType.h +++ b/velox/functions/prestosql/types/IPPrefixType.h @@ -89,6 +89,24 @@ Status handleFailedToCreateNetworkError( } } // namespace +inline folly::Expected tryIpPrefixLengthFromIPAddressType( + const int128_t& intIpAddr) { + folly::ByteArray16 addrBytes = {0}; + memcpy(&addrBytes, &intIpAddr, sizeof(intIpAddr)); + std::reverse(addrBytes.begin(), addrBytes.end()); + auto tryV6Addr = folly::IPAddressV6::tryFromBinary(addrBytes); + if (tryV6Addr.hasError()) { + return folly::makeUnexpected( + threadSkipErrorDetails() + ? Status::UserError() + : Status::UserError( + "Received invalid ip address '{}'", tryV6Addr.error())); + } + + return tryV6Addr.value().isIPv4Mapped() ? ipaddress::kIPV4Bits + : ipaddress::kIPV6Bits; +} + inline folly::Expected, Status> tryParseIpPrefixString(folly::StringPiece ipprefixString) { // Ensure '/' is present