diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index 239bec57a7d7..c81a6043094a 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -1258,4 +1258,57 @@ class TestOperator extends VeloxWholeStageTransformerSuite { } } } + + test("Support Map type signature") { + // test map + withTempView("t1") { + Seq[(Int, Map[String, String])]((1, Map("byte1" -> "aaa")), (2, Map("byte2" -> "bbbb"))) + .toDF("c1", "map_c2") + .createTempView("t1") + runQueryAndCompare(""" + |SELECT c1, collect_list(map_c2) FROM t1 group by c1; + |""".stripMargin) { + checkOperatorMatch[HashAggregateExecTransformer] + } + } + // test map> + withTempView("t2") { + Seq[(Int, Map[String, Map[String, String]])]( + (1, Map("byte1" -> Map("test1" -> "aaaa"))), + (2, Map("byte2" -> Map("test1" -> "bbbb")))) + .toDF("c1", "map_c2") + .createTempView("t2") + runQueryAndCompare(""" + |SELECT c1, collect_list(map_c2) FROM t2 group by c1; + |""".stripMargin) { + checkOperatorMatch[HashAggregateExecTransformer] + } + } + // test map,map> + withTempView("t3") { + Seq[(Int, Map[Map[String, String], Map[String, String]])]( + (1, Map(Map("byte1" -> "aaaa") -> Map("test1" -> "aaaa"))), + (2, Map(Map("byte2" -> "bbbb") -> Map("test1" -> "bbbb")))) + .toDF("c1", "map_c2") + .createTempView("t3") + runQueryAndCompare(""" + |SELECT collect_list(map_c2) FROM t3 group by c1; + |""".stripMargin) { + checkOperatorMatch[HashAggregateExecTransformer] + } + } + // test map> + withTempView("t4") { + Seq[(Int, Map[String, Array[String]])]( + (1, Map("test1" -> Array("test1", "test2"))), + (2, Map("test2" -> Array("test1", "test2")))) + .toDF("c1", "map_c2") + .createTempView("t4") + runQueryAndCompare(""" + |SELECT collect_list(map_c2) FROM t4 group by c1; + |""".stripMargin) { + checkOperatorMatch[HashAggregateExecTransformer] + } + } + } } diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc b/cpp/velox/substrait/VeloxSubstraitSignature.cc index 2d2432281071..34e0df6de2fd 100644 --- a/cpp/velox/substrait/VeloxSubstraitSignature.cc +++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc @@ -121,33 +121,25 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix; }; - if (startWith(signature, "dec")) { - // Decimal type name is in the format of dec. - auto precisionStart = signature.find_first_of('<'); - auto tokenIndex = signature.find_first_of(','); - auto scaleEnd = signature.find_first_of('>'); - auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - precisionStart - 1))); - auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex - 1))); - return DECIMAL(precision, scale); - } - - if (startWith(signature, "struct")) { - // Struct type name is in the format of struct. - auto structStart = signature.find_first_of('<'); - auto structEnd = signature.find_last_of('>'); + auto parseNestedTypeSignature = [&](const std::string& signature) -> std::vector { + auto start = signature.find_first_of('<'); + auto end = signature.find_last_of('>'); VELOX_CHECK( - structEnd - structStart > 1, "Native validation failed due to: more information is needed to create RowType"); - std::string childrenTypes = signature.substr(structStart + 1, structEnd - structStart - 1); + end - start > 1, + "Native validation failed due to: more information is needed to create nested type for {}", + signature); + + std::string childrenTypes = signature.substr(start + 1, end - start - 1); // Split the types with delimiter. std::string delimiter = ","; std::size_t pos; std::vector types; - std::vector names; while ((pos = childrenTypes.find(delimiter)) != std::string::npos) { auto typeStr = childrenTypes.substr(0, pos); std::size_t endPos = pos; - if (startWith(typeStr, "dec") || startWith(typeStr, "struct")) { + if (startWith(typeStr, "dec") || startWith(typeStr, "struct") || startWith(typeStr, "map") || + startWith(typeStr, "list")) { endPos = childrenTypes.find(">") + 1; if (endPos > pos) { typeStr += childrenTypes.substr(pos, endPos - pos); @@ -159,16 +151,43 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa } } types.emplace_back(fromSubstraitSignature(typeStr)); - names.emplace_back(""); childrenTypes.erase(0, endPos + delimiter.length()); } if (childrenTypes.size() > 0 && !startWith(childrenTypes, ">")) { types.emplace_back(fromSubstraitSignature(childrenTypes)); - names.emplace_back(""); + } + return types; + }; + + if (startWith(signature, "dec")) { + // Decimal type name is in the format of dec. + auto precisionStart = signature.find_first_of('<'); + auto tokenIndex = signature.find_first_of(','); + auto scaleEnd = signature.find_first_of('>'); + auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - precisionStart - 1))); + auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex - 1))); + return DECIMAL(precision, scale); + } + + if (startWith(signature, "struct")) { + // Struct type name is in the format of struct. + auto types = parseNestedTypeSignature(signature); + std::vector names(types.size()); + for (int i = 0; i < types.size(); i++) { + names[i] = ""; } return std::make_shared(std::move(names), std::move(types)); } + if (startWith(signature, "map")) { + // Map type name is in the format of map. + auto types = parseNestedTypeSignature(signature); + if (types.size() != 2) { + VELOX_UNSUPPORTED("Substrait type signature conversion to Velox type not supported for {}.", signature); + } + return MAP(std::move(types)[0], std::move(types)[1]); + } + if (startWith(signature, "list")) { auto listStart = signature.find_first_of('<'); auto listEnd = signature.find_last_of('>'); diff --git a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc index bbc1165add88..d6db661f76cd 100644 --- a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc +++ b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc @@ -139,6 +139,22 @@ TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) { type = fromSubstraitSignature("struct>>>"); ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::HUGEINT); ASSERT_ANY_THROW(fromSubstraitSignature("other")->kind()); + + // Map type test. + type = fromSubstraitSignature("map>>"); + ASSERT_EQ(type->kind(), TypeKind::MAP); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN); + ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY); + ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP); + type = fromSubstraitSignature("struct,list>>"); + ASSERT_EQ(type->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::MAP); + ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BOOLEAN); + ASSERT_EQ(type->childAt(0)->childAt(1)->kind(), TypeKind::TINYINT); + ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY); + ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP); + ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(0)->kind(), TypeKind::VARCHAR); + ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(1)->kind(), TypeKind::INTEGER); } } // namespace gluten diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala index ec03c920a833..315fa4d31c48 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala @@ -378,8 +378,13 @@ object ConverterUtils extends Logging { }) sigName = sigName.concat(">") sigName - case MapType(_, _, _) => - "map" + case MapType(keyType, valueType, _) => + var sigName = "map<" + sigName = sigName.concat(getTypeSigName(keyType)) + sigName = sigName.concat(",") + sigName = sigName.concat(getTypeSigName(valueType)) + sigName = sigName.concat(">") + sigName case CharType(_) => "fchar" case NullType =>