Skip to content

Commit

Permalink
[COR-2297/] Fix nested offloaded type validation (#552)
Browse files Browse the repository at this point in the history
The following workflow works when we are not offloading literals in flytekit:

```
import logging
from typing import List
from flytekit import map_task, task, workflow,LaunchPlan

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("flytekit")
logger.setLevel(logging.DEBUG)

@task(cache=True, cache_version="1.1")
def my_30mb_task(i: str) -> str:
    return f"Hello world {i}" * 30 * 100 * 1024

@task(cache=True, cache_version="1.1")
def generate_strs(count: int) -> List[str]:
    return ["a"] * count

@workflow
def my_30mb_wf(mbs: int) -> List[str]:
  strs = generate_strs(count=mbs)
  return map_task(my_30mb_task)(i=strs)

@workflow
def big_inputs_wf(input: List[str]):
   noop()

@task(cache=True, cache_version="1.1")
def noop():
    ...

big_inputs_wf_lp = LaunchPlan.get_or_create(name="big_inputs_wf_lp", workflow=big_inputs_wf)

@workflow
def ref_wf(mbs: int):
  big_inputs_wf_lp(input=my_30mb_wf(mbs))
```

Without flytekit offloading, the return type is OffloadedLiteral{inferredType:{Collection{String}}}. When it is checked against the big_inputs_wf launch plan, which needs Collection{String}, LiteralTypeForLiteral returns the inferredType: Collection{String}, so the check passes.

If we enable offloading in flytekit, the data returned from the map task is
Collection{OffloadedLiteral{inferredType:{Collection{String}}}}

When this input is passed to big_inputs_wf, which takes Collection{String}, the type check fails because LiteralTypeForLiteral resolves Collection{OffloadedLiteral{inferredType:{Collection{String}}}} as Collection{Collection{String}}.

Flytekit handles this case by special-casing Collection{OffloadedLiteral}; similar special-casing is needed in the flyte code base.

Tested this by deploying the changes in this PR:

https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/akxs97cdmkmxhhqp228x/nodes

Earlier it would fail like this: https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/ap4thjp5528kjfspcsds/nodes
```
[UserError] failed to launch workflow, caused by: rpc error: code = InvalidArgument desc = invalid input input wrong type. Expected collection_type:{simple:STRING}, but got collection_type:{collection_type:{simple:STRING}}
```

Roll out to canary, and then to all prod BYOC and serverless tenants.

Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [x] To be upstreamed to OSS

*TODO: Link Linear issue(s) using [magic words](https://linear.app/docs/github#magic-words). `fixes` will move to merged status, while `ref` will only link the PR.*

* [X] Added tests
* [ ] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation
  • Loading branch information
pmahindrakar-oss committed Nov 12, 2024
1 parent 3c3ae05 commit 7f0db4a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 10 deletions.
25 changes: 19 additions & 6 deletions flytepropeller/pkg/compiler/validators/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,18 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType {
return unionLiteralType
}

func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) {
innerType := make([]*core.LiteralType, 0, 1)
innerTypeSet := sets.NewString()
var noneType *core.LiteralType
isOffloadedType := false
for _, x := range literals {
otherType := LiteralTypeForLiteral(x)
otherTypeKey := otherType.String()
if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok {
isOffloadedType = true
return otherType, isOffloadedType
}
if _, ok := x.GetValue().(*core.Literal_Collection); ok {
if x.GetCollection().GetLiterals() == nil {
noneType = otherType
Expand All @@ -230,9 +235,9 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
if len(innerType) == 0 {
return &core.LiteralType{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE},
}
}, isOffloadedType
} else if len(innerType) == 1 {
return innerType[0]
return innerType[0], isOffloadedType
}

// sort inner types to ensure consistent union types are generated
Expand All @@ -247,7 +252,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {

return 0
})
return buildMultipleTypeUnion(innerType)
return buildMultipleTypeUnion(innerType), isOffloadedType
}

// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid.
Expand All @@ -274,15 +279,23 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType {
case *core.Literal_Scalar:
return literalTypeForScalar(l.GetScalar())
case *core.Literal_Collection:
collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().Literals)
if isOffloaded {
return collectionType
}
return &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: literalTypeForLiterals(l.GetCollection().Literals),
CollectionType: collectionType,
},
}
case *core.Literal_Map:
mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().Literals))
if isOffloaded {
return mapValueType
}
return &core.LiteralType{
Type: &core.LiteralType_MapValueType{
MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)),
MapValueType: mapValueType,
},
}
case *core.Literal_OffloadedMetadata:
Expand Down
95 changes: 91 additions & 4 deletions flytepropeller/pkg/compiler/validators/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import (

func TestLiteralTypeForLiterals(t *testing.T) {
t.Run("empty", func(t *testing.T) {
lt := literalTypeForLiterals(nil)
lt, isOffloaded := literalTypeForLiterals(nil)
assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("binary idl with raw binary data and no tag", func(t *testing.T) {
Expand Down Expand Up @@ -94,17 +95,18 @@ func TestLiteralTypeForLiterals(t *testing.T) {
})

t.Run("homogeneous", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral(0),
coreutils.MustMakeLiteral(5),
})

assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("non-homogenous", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral("hello"),
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral("world"),
Expand All @@ -115,10 +117,11 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.Len(t, lt.GetUnionType().Variants, 2)
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("non-homogenous ensure ordering", func(t *testing.T) {
lt := literalTypeForLiterals([]*core.Literal{
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
coreutils.MustMakeLiteral(5),
coreutils.MustMakeLiteral("world"),
coreutils.MustMakeLiteral(0),
Expand All @@ -128,6 +131,7 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.Len(t, lt.GetUnionType().Variants, 2)
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
assert.False(t, isOffloaded)
})

t.Run("list with mixed types", func(t *testing.T) {
Expand Down Expand Up @@ -454,6 +458,89 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.True(t, proto.Equal(expectedLt, lt))
})

t.Run("nested Lists of offloaded List of string types", func(t *testing.T) {
inferredType := &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
}
literals := &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-1",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-2",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
},
},
},
}
expectedLt := inferredType
lt := LiteralTypeForLiteral(literals)
assert.True(t, proto.Equal(expectedLt, lt))
})
t.Run("nested map of offloaded map of string types", func(t *testing.T) {
inferredType := &core.LiteralType{
Type: &core.LiteralType_MapValueType{
MapValueType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
}
literals := &core.Literal{
Value: &core.Literal_Map{
Map: &core.LiteralMap{
Literals: map[string]*core.Literal{

"key1": {
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-1",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
"key2": {
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri-2",
SizeBytes: 1000,
InferredType: inferredType,
},
},
},
},
},
},
}

expectedLt := inferredType
lt := LiteralTypeForLiteral(literals)
assert.True(t, proto.Equal(expectedLt, lt))
})

}

func TestJoinVariableMapsUniqueKeys(t *testing.T) {
Expand Down

0 comments on commit 7f0db4a

Please sign in to comment.