Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-4830][VL] Support MapType substrait signature #4833

Merged
merged 3 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1258,4 +1258,57 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}
}
}

// Runs collect_list over map-typed columns of several key/value shapes and,
// for each query, applies checkOperatorMatch[HashAggregateExecTransformer]
// as the custom check passed to runQueryAndCompare.
test("Support Map type signature") {
  // Case 1: map<str,str>
  withTempView("t1") {
    val rows = Seq[(Int, Map[String, String])](
      (1, Map("byte1" -> "aaa")),
      (2, Map("byte2" -> "bbbb")))
    rows.toDF("c1", "map_c2").createTempView("t1")
    val query =
      """
        |SELECT c1, collect_list(map_c2) FROM t1 group by c1;
        |""".stripMargin
    runQueryAndCompare(query) {
      checkOperatorMatch[HashAggregateExecTransformer]
    }
  }

  // Case 2: map<str,map<str,str>> (map nested as the value)
  withTempView("t2") {
    val rows = Seq[(Int, Map[String, Map[String, String]])](
      (1, Map("byte1" -> Map("test1" -> "aaaa"))),
      (2, Map("byte2" -> Map("test1" -> "bbbb"))))
    rows.toDF("c1", "map_c2").createTempView("t2")
    val query =
      """
        |SELECT c1, collect_list(map_c2) FROM t2 group by c1;
        |""".stripMargin
    runQueryAndCompare(query) {
      checkOperatorMatch[HashAggregateExecTransformer]
    }
  }

  // Case 3: map<map<str,str>,map<str,str>> (maps as both key and value)
  withTempView("t3") {
    val rows = Seq[(Int, Map[Map[String, String], Map[String, String]])](
      (1, Map(Map("byte1" -> "aaaa") -> Map("test1" -> "aaaa"))),
      (2, Map(Map("byte2" -> "bbbb") -> Map("test1" -> "bbbb"))))
    rows.toDF("c1", "map_c2").createTempView("t3")
    val query =
      """
        |SELECT collect_list(map_c2) FROM t3 group by c1;
        |""".stripMargin
    runQueryAndCompare(query) {
      checkOperatorMatch[HashAggregateExecTransformer]
    }
  }

  // Case 4: map<str,list<str>> (array as the value)
  withTempView("t4") {
    val rows = Seq[(Int, Map[String, Array[String]])](
      (1, Map("test1" -> Array("test1", "test2"))),
      (2, Map("test2" -> Array("test1", "test2"))))
    rows.toDF("c1", "map_c2").createTempView("t4")
    val query =
      """
        |SELECT collect_list(map_c2) FROM t4 group by c1;
        |""".stripMargin
    runQueryAndCompare(query) {
      checkOperatorMatch[HashAggregateExecTransformer]
    }
  }
}
}
59 changes: 39 additions & 20 deletions cpp/velox/substrait/VeloxSubstraitSignature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,33 +121,25 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa
return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix;
};

if (startWith(signature, "dec")) {
// Decimal type name is in the format of dec<precision,scale>.
auto precisionStart = signature.find_first_of('<');
auto tokenIndex = signature.find_first_of(',');
auto scaleEnd = signature.find_first_of('>');
auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - precisionStart - 1)));
auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex - 1)));
return DECIMAL(precision, scale);
}

if (startWith(signature, "struct")) {
// Struct type name is in the format of struct<T1,T2,...,Tn>.
auto structStart = signature.find_first_of('<');
auto structEnd = signature.find_last_of('>');
auto parseNestedTypeSignature = [&](const std::string& signature) -> std::vector<TypePtr> {
auto start = signature.find_first_of('<');
auto end = signature.find_last_of('>');
VELOX_CHECK(
structEnd - structStart > 1, "Native validation failed due to: more information is needed to create RowType");
std::string childrenTypes = signature.substr(structStart + 1, structEnd - structStart - 1);
end - start > 1,
"Native validation failed due to: more information is needed to create nested type for {}",
signature);

std::string childrenTypes = signature.substr(start + 1, end - start - 1);

// Split the types with delimiter.
std::string delimiter = ",";
std::size_t pos;
std::vector<TypePtr> types;
std::vector<std::string> names;
while ((pos = childrenTypes.find(delimiter)) != std::string::npos) {
auto typeStr = childrenTypes.substr(0, pos);
std::size_t endPos = pos;
if (startWith(typeStr, "dec") || startWith(typeStr, "struct")) {
if (startWith(typeStr, "dec") || startWith(typeStr, "struct") || startWith(typeStr, "map") ||
startWith(typeStr, "list")) {
endPos = childrenTypes.find(">") + 1;
if (endPos > pos) {
typeStr += childrenTypes.substr(pos, endPos - pos);
Expand All @@ -159,16 +151,43 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa
}
}
types.emplace_back(fromSubstraitSignature(typeStr));
names.emplace_back("");
childrenTypes.erase(0, endPos + delimiter.length());
}
if (childrenTypes.size() > 0 && !startWith(childrenTypes, ">")) {
types.emplace_back(fromSubstraitSignature(childrenTypes));
names.emplace_back("");
}
return types;
};

if (startWith(signature, "dec")) {
// Decimal type name is in the format of dec<precision,scale>.
auto precisionStart = signature.find_first_of('<');
auto tokenIndex = signature.find_first_of(',');
auto scaleEnd = signature.find_first_of('>');
auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - precisionStart - 1)));
auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex - 1)));
return DECIMAL(precision, scale);
}

if (startWith(signature, "struct")) {
// Struct type name is in the format of struct<T1,T2,...,Tn>.
auto types = parseNestedTypeSignature(signature);
std::vector<std::string> names(types.size());
for (int i = 0; i < types.size(); i++) {
names[i] = "";
}
return std::make_shared<RowType>(std::move(names), std::move(types));
}

if (startWith(signature, "map")) {
// Map type name is in the format of map<T1,T2>.
auto types = parseNestedTypeSignature(signature);
if (types.size() != 2) {
VELOX_UNSUPPORTED("Substrait type signature conversion to Velox type not supported for {}.", signature);
}
return MAP(std::move(types)[0], std::move(types)[1]);
}

if (startWith(signature, "list")) {
auto listStart = signature.find_first_of('<');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WangGuangxin, it seems we can also use parseNestedTypeSignature to handle list type. Right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that parseNestedTypeSignature can't properly handle the list type: a list type's child isn't delimited by a comma, so it will not work for types like list<map<str,str>>.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WangGuangxin, I see. Thanks!

auto listEnd = signature.find_last_of('>');
Expand Down
16 changes: 16 additions & 0 deletions cpp/velox/tests/VeloxSubstraitSignatureTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) {
type = fromSubstraitSignature("struct<struct<struct<i8,dec<19,2>>>>");
ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::HUGEINT);
ASSERT_ANY_THROW(fromSubstraitSignature("other")->kind());

// Map type test.
type = fromSubstraitSignature("map<bool,list<map<str,i32>>>");
ASSERT_EQ(type->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
type = fromSubstraitSignature("struct<map<bool,i8>,list<map<str,i32>>>");
ASSERT_EQ(type->kind(), TypeKind::ROW);
ASSERT_EQ(type->childAt(0)->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BOOLEAN);
ASSERT_EQ(type->childAt(0)->childAt(1)->kind(), TypeKind::TINYINT);
ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ARRAY);
ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::MAP);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(0)->kind(), TypeKind::VARCHAR);
ASSERT_EQ(type->childAt(1)->childAt(0)->childAt(1)->kind(), TypeKind::INTEGER);
}

} // namespace gluten
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,13 @@ object ConverterUtils extends Logging {
})
sigName = sigName.concat(">")
sigName
case MapType(_, _, _) =>
"map"
case MapType(keyType, valueType, _) =>
var sigName = "map<"
sigName = sigName.concat(getTypeSigName(keyType))
sigName = sigName.concat(",")
sigName = sigName.concat(getTypeSigName(valueType))
sigName = sigName.concat(">")
sigName
case CharType(_) =>
"fchar"
case NullType =>
Expand Down
Loading