Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Upstream] [COR-2297/] Fix nested offloaded type validation (#552) #5996

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions flytepropeller/pkg/compiler/validators/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,18 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType {
return unionLiteralType
}

func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) {
innerType := make([]*core.LiteralType, 0, 1)
innerTypeSet := sets.NewString()
var noneType *core.LiteralType
isOffloadedType := false
for _, x := range literals {
otherType := LiteralTypeForLiteral(x)
otherTypeKey := otherType.String()
if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok {
isOffloadedType = true
return otherType, isOffloadedType
}
if _, ok := x.GetValue().(*core.Literal_Collection); ok {
if x.GetCollection().GetLiterals() == nil {
noneType = otherType
Expand All @@ -230,9 +235,9 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
if len(innerType) == 0 {
return &core.LiteralType{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE},
}
}, isOffloadedType
} else if len(innerType) == 1 {
return innerType[0]
return innerType[0], isOffloadedType
}

// sort inner types to ensure consistent union types are generated
Expand All @@ -247,7 +252,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {

return 0
})
return buildMultipleTypeUnion(innerType)
return buildMultipleTypeUnion(innerType), isOffloadedType
}

// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid.
Expand All @@ -274,15 +279,23 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType {
case *core.Literal_Scalar:
return literalTypeForScalar(l.GetScalar())
case *core.Literal_Collection:
collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().Literals)
if isOffloaded {
return collectionType
}
return &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: literalTypeForLiterals(l.GetCollection().Literals),
CollectionType: collectionType,
},
}
case *core.Literal_Map:
mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().Literals))
if isOffloaded {
return mapValueType
}
return &core.LiteralType{
Type: &core.LiteralType_MapValueType{
MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)),
MapValueType: mapValueType,
},
}
case *core.Literal_OffloadedMetadata:
Expand Down
95 changes: 91 additions & 4 deletions flytepropeller/pkg/compiler/validators/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import (

func TestLiteralTypeForLiterals(t *testing.T) {
t.Run("empty", func(t *testing.T) {
lt := literalTypeForLiterals(nil)
lt, isOffloaded := literalTypeForLiterals(nil)
assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("binary idl with raw binary data and no tag", func(t *testing.T) {
Expand Down Expand Up @@ -94,17 +95,18 @@ func TestLiteralTypeForLiterals(t *testing.T) {
})

t.Run("homogeneous", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral(0),
coreutils.MustMakeLiteral(5),
})

assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("non-homogenous", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral("hello"),
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral("world"),
Expand All @@ -115,10 +117,11 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.Len(t, lt.GetUnionType().Variants, 2)
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("non-homogenous ensure ordering", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral("world"),
coreutils.MustMakeLiteral(0),
Expand All @@ -128,6 +131,7 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.Len(t, lt.GetUnionType().Variants, 2)
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("list with mixed types", func(t *testing.T) {
Expand Down Expand Up @@ -454,6 +458,89 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.True(t, proto.Equal(expectedLt, lt))
})

t.Run("nested Lists of offloaded List of string types", func(t *testing.T) {
inferredType := &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
}
literals := &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-1",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-2",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
},
},
},
}
expectedLt := inferredType
lt := LiteralTypeForLiteral(literals)
assert.True(t, proto.Equal(expectedLt, lt))
})
t.Run("nested map of offloaded map of string types", func(t *testing.T) {
inferredType := &core.LiteralType{
Type: &core.LiteralType_MapValueType{
MapValueType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
}
literals := &core.Literal{
Value: &core.Literal_Map{
Map: &core.LiteralMap{
Literals: map[string]*core.Literal{

"key1": {
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-1",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
"key2": {
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-2",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
},
},
},
}

expectedLt := inferredType
lt := LiteralTypeForLiteral(literals)
assert.True(t, proto.Equal(expectedLt, lt))
})

}

func TestJoinVariableMapsUniqueKeys(t *testing.T) {
Expand Down
Loading