diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 836bc69979..7d51b3b695 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -166,6 +166,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sendgrid/rest v2.6.9+incompatible // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flytecopilot/go.mod b/flytecopilot/go.mod index e1dbdc7683..166411654f 100644 --- a/flytecopilot/go.mod +++ b/flytecopilot/go.mod @@ -83,6 +83,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flytecopilot/go.sum b/flytecopilot/go.sum index 9fb93ec715..0e3773721b 100644 --- a/flytecopilot/go.sum +++ b/flytecopilot/go.sum @@ -311,6 +311,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/flytectl/go.mod b/flytectl/go.mod index d783ac0513..8829eb881f 100644 --- a/flytectl/go.mod +++ b/flytectl/go.mod @@ -142,6 +142,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/flytectl/go.sum b/flytectl/go.sum index 1e3b5d7ef8..cb4054c995 100644 --- a/flytectl/go.sum +++ b/flytectl/go.sum @@ -420,6 +420,8 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/flyteidl/clients/go/coreutils/extract_literal.go b/flyteidl/clients/go/coreutils/extract_literal.go index f9918dd0f8..23302de9a3 100644 --- a/flyteidl/clients/go/coreutils/extract_literal.go +++ b/flyteidl/clients/go/coreutils/extract_literal.go @@ -54,6 +54,8 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { default: return nil, fmt.Errorf("unsupported literal scalar primitive type %T", scalarValue) } + case *core.Scalar_Binary: + return scalarValue.Binary, nil case *core.Scalar_Blob: return scalarValue.Blob.Uri, nil case *core.Scalar_Schema: diff --git a/flyteidl/clients/go/coreutils/extract_literal_test.go b/flyteidl/clients/go/coreutils/extract_literal_test.go index 760e7bee0a..67e27fb74f 100644 --- a/flyteidl/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl/clients/go/coreutils/extract_literal_test.go @@ -113,7 +113,7 @@ func TestFetchLiteral(t *testing.T) { s := MakeBinaryLiteral([]byte{'h'}) assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) _, err := ExtractFromLiteral(s) - assert.NotNil(t, err) + assert.Nil(t, err) }) t.Run("NoneType", func(t *testing.T) { @@ -124,34 +124,6 @@ func TestFetchLiteral(t *testing.T) { assert.Nil(t, err) }) - t.Run("Generic", func(t *testing.T) { - literalVal := map[string]interface{}{ - "x": 1, - "y": "ystringvalue", - } - var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} - lit, err := MakeLiteralForType(literalType, literalVal) - assert.NoError(t, err) - extractedLiteralVal, err := ExtractFromLiteral(lit) - assert.NoError(t, err) - fieldsMap := map[string]*structpb.Value{ - "x": { - Kind: &structpb.Value_NumberValue{NumberValue: 1}, - }, - "y": { - Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"}, - }, - } - expectedStructVal := &structpb.Struct{ - Fields: fieldsMap, - } - extractedStructValue := extractedLiteralVal.(*structpb.Struct) - assert.Equal(t, len(expectedStructVal.Fields), len(extractedStructValue.Fields)) - for key, val := range expectedStructVal.Fields { - assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) - } - }) - t.Run("Generic Passed As String", func(t *testing.T) { literalVal := "{\"x\": 1,\"y\": \"ystringvalue\"}" var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} diff --git a/flyteidl/clients/go/coreutils/literals.go b/flyteidl/clients/go/coreutils/literals.go index 278fa30dfc..c0244c5190 100644 --- a/flyteidl/clients/go/coreutils/literals.go +++ b/flyteidl/clients/go/coreutils/literals.go @@ -2,7 +2,6 @@ package coreutils import ( - "encoding/json" "fmt" "math" "reflect" @@ -14,11 +13,14 @@ import ( "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/storage" ) +const MESSAGEPACK = "msgpack" + func MakePrimitive(v interface{}) (*core.Primitive, error) { switch p := v.(type) { case int: @@ -144,6 +146,7 @@ func MakeBinaryLiteral(v []byte) *core.Literal { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: v, + Tag: MESSAGEPACK, }, }, }, @@ -389,7 +392,7 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error scalar.Value = &core.Scalar_Binary{ Binary: &core.Binary{ Value: []byte(s), - // TODO Tag not supported at the moment + Tag: MESSAGEPACK, }, } case core.SimpleType_ERROR: @@ -559,12 +562,35 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro strValue = fmt.Sprintf("%.0f", math.Trunc(f)) } if newT.Simple == core.SimpleType_STRUCT { + // If the type is a STRUCT, we expect the input to be a complex object + // like the following example: + // inputs: + // dc: + // a: 1 + // b: 3.14 + // c: "example_string" + // Instead of storing it directly as a structured value, we will serialize + // the input object using MsgPack and return it as a binary IDL object. + + // If the value is not already a string (meaning it's not already serialized), + // proceed with serialization. if _, isValueStringType := v.(string); !isValueStringType { - byteValue, err := json.Marshal(v) + byteValue, err := msgpack.Marshal(v) if err != nil { return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) } - strValue = string(byteValue) + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: byteValue, + Tag: MESSAGEPACK, + }, + }, + }, + }, + }, nil } } lv, err := MakeLiteralForSimpleType(newT.Simple, strValue) diff --git a/flyteidl/clients/go/coreutils/literals_test.go b/flyteidl/clients/go/coreutils/literals_test.go index 24a0af4865..009703bc1c 100644 --- a/flyteidl/clients/go/coreutils/literals_test.go +++ b/flyteidl/clients/go/coreutils/literals_test.go @@ -14,6 +14,7 @@ import ( "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -242,6 +243,16 @@ func TestMakeDefaultLiteralForType(t *testing.T) { assert.NotNil(t, l.GetScalar().GetError()) }) + t.Run("binary", func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BINARY, + }}) + assert.NoError(t, err) + assert.NotNil(t, l.GetScalar().GetBinary()) + assert.NotNil(t, l.GetScalar().GetBinary().GetValue()) + assert.NotNil(t, l.GetScalar().GetBinary().GetTag()) + }) + t.Run("struct", func(t *testing.T) { l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ Simple: core.SimpleType_STRUCT, @@ -444,6 +455,68 @@ func TestMakeLiteralForType(t *testing.T) { assert.Equal(t, expectedVal, actualVal) }) + t.Run("SimpleBinary", func(t *testing.T) { + // We compare the deserialized values instead of the raw msgpack bytes because Go does not guarantee the order + // of map keys during serialization. This means that while the serialized bytes may differ, the deserialized + // values should be logically equivalent. + + var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + v := map[string]interface{}{ + "a": int64(1), + "b": 3.14, + "c": "example_string", + "d": map[string]interface{}{ + "1": int64(100), + "2": int64(200), + }, + "e": map[string]interface{}{ + "a": int64(1), + "b": 3.14, + }, + "f": []string{"a", "b", "c"}, + } + + val, err := MakeLiteralForType(literalType, v) + assert.NoError(t, err) + + msgpackBytes, err := msgpack.Marshal(v) + assert.NoError(t, err) + + literalVal := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: msgpackBytes, + Tag: MESSAGEPACK, + }, + }, + }, + }, + } + + expectedLiteralVal, err := ExtractFromLiteral(literalVal) + assert.NoError(t, err) + actualLiteralVal, err := ExtractFromLiteral(val) + assert.NoError(t, err) + + // Check if the extracted value is of type *core.Binary (not []byte) + expectedBinary, ok := expectedLiteralVal.(*core.Binary) + assert.True(t, ok, "expectedLiteralVal is not of type *core.Binary") + actualBinary, ok := actualLiteralVal.(*core.Binary) + assert.True(t, ok, "actualLiteralVal is not of type *core.Binary") + + // Now check if the Binary values match + var expectedVal, actualVal map[string]interface{} + err = msgpack.Unmarshal(expectedBinary.Value, &expectedVal) + assert.NoError(t, err) + err = msgpack.Unmarshal(actualBinary.Value, &actualVal) + assert.NoError(t, err) + + // Finally, assert that the deserialized values are equal + assert.Equal(t, expectedVal, actualVal) + }) + t.Run("ArrayStrings", func(t *testing.T) { var literalType = &core.LiteralType{Type: &core.LiteralType_CollectionType{ CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}}} diff --git a/flyteidl/go.mod b/flyteidl/go.mod index 55ec124554..037cac70cd 100644 --- a/flyteidl/go.mod +++ b/flyteidl/go.mod @@ -13,6 +13,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/pkg/errors v0.9.1 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 golang.org/x/net v0.27.0 diff --git a/flyteidl/go.sum b/flyteidl/go.sum index 5d5cb7e9a2..e1e7d9782d 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -217,6 +217,8 @@ github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index c4287581bc..1d75c44ac3 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -108,6 +108,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index fa26e3cfda..3721c28a7a 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -342,6 +342,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index d1bffbbe09..cbb14b3124 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -8,6 +8,7 @@ import ( "golang.org/x/exp/slices" "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" ) @@ -47,7 +48,7 @@ func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { // If the binary has a tag, treat it as a structured type (e.g., dict, dataclass, Pydantic BaseModel). // Otherwise, treat it as raw binary data. // Reference: https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md - if len(v.Binary.Tag) > 0 { + if v.Binary.Tag == coreutils.MESSAGEPACK { literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} } else { literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index b6737c7e62..09790849f3 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -55,7 +55,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: serializedBinaryData, - Tag: "msgpack", + Tag: coreutils.MESSAGEPACK, }, }, }, @@ -83,7 +83,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: serializedBinaryData, - Tag: "msgpack", + Tag: coreutils.MESSAGEPACK, }, }, }, diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index fa19d2bf5c..192fa1956c 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -3,8 +3,10 @@ package nodes import ( "context" + "github.com/shamaton/msgpack/v2" "google.golang.org/protobuf/types/known/structpb" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/common" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" @@ -18,7 +20,7 @@ func resolveAttrPathInPromise(ctx context.Context, datastore *storage.DataStore, var tmpVal *core.Literal var err error var exist bool - count := 0 + index := 0 for _, attr := range bindAttrPath { if currVal.GetOffloadedMetadata() != nil { @@ -37,26 +39,31 @@ func resolveAttrPathInPromise(ctx context.Context, datastore *storage.DataStore, return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal.GetMap().GetLiterals()) } currVal = tmpVal - count++ + index++ case *core.Literal_Collection: if int(attr.GetIntValue()) >= len(currVal.GetCollection().GetLiterals()) { return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal.GetCollection().GetLiterals()) } currVal = currVal.GetCollection().GetLiterals()[attr.GetIntValue()] - count++ + index++ // scalar is always the leaf, so we can break here case *core.Literal_Scalar: break } } - // resolve dataclass - if currVal.GetScalar() != nil && currVal.GetScalar().GetGeneric() != nil { - st := currVal.GetScalar().GetGeneric() - // start from index "count" - currVal, err = resolveAttrPathInPbStruct(nodeID, st, bindAttrPath[count:]) - if err != nil { - return nil, err + // resolve dataclass and Pydantic BaseModel + if scalar := currVal.GetScalar(); scalar != nil { + if binary := scalar.GetBinary(); binary != nil { + currVal, err = resolveAttrPathInBinary(nodeID, binary, bindAttrPath[index:]) + if err != nil { + return nil, err + } + } else if generic := scalar.GetGeneric(); generic != nil { + currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[index:]) + if err != nil { + return nil, err + } } } @@ -66,8 +73,8 @@ func resolveAttrPathInPromise(ctx context.Context, datastore *storage.DataStore, // resolveAttrPathInPbStruct resolves the protobuf struct (e.g. dataclass) with attribute path func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { - var currVal interface{} - var tmpVal interface{} + var currVal any + var tmpVal any var exist bool currVal = st.AsMap() @@ -76,16 +83,18 @@ func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath for _, attr := range bindAttrPath { switch resolvedVal := currVal.(type) { // map - case map[string]interface{}: + case map[string]any: tmpVal, exist = resolvedVal[attr.GetStringValue()] if !exist { return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal) } currVal = tmpVal // list - case []interface{}: - if int(attr.GetIntValue()) >= len(resolvedVal) { - return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal) + case []any: + index := int(attr.GetIntValue()) + if index < 0 || index >= len(resolvedVal) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "index [%v] is out of range of %v", index, resolvedVal) } currVal = resolvedVal[attr.GetIntValue()] } @@ -97,6 +106,89 @@ func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath return literal, err } +// resolveAttrPathInBinary resolves the binary idl object (e.g. dataclass, pydantic basemodel) with attribute path +func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath []*core.PromiseAttribute) (*core. + Literal, + error) { + + binaryBytes := binaryIDL.GetValue() + serializationFormat := binaryIDL.GetTag() + + var currVal any + var tmpVal any + var exist bool + + if serializationFormat == coreutils.MESSAGEPACK { + err := msgpack.Unmarshal(binaryBytes, &currVal) + if err != nil { + return nil, err + } + } else { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "Unsupported format '%v' found for literal value.\n"+ + "Please ensure the serialization format is supported.", serializationFormat) + } + + // Turn the current value to a map, so it can be resolved more easily + for _, attr := range bindAttrPath { + switch resolvedVal := currVal.(type) { + // map + case map[any]any: + // TODO: for cases like Dict[int, Any] in a dataclass, this will fail, + // will support it in the future when flytekit supports it + promise, ok := attr.GetValue().(*core.PromiseAttribute_StringValue) + if !ok { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "unexpected attribute type [%T] for value %v", attr.GetValue(), attr.GetValue()) + } + key := promise.StringValue + tmpVal, exist = resolvedVal[key] + if !exist { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal) + } + currVal = tmpVal + // list + case []any: + promise, ok := attr.GetValue().(*core.PromiseAttribute_IntValue) + if !ok { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "unexpected attribute type [%T] for value %v", attr.GetValue(), attr.GetValue()) + } + index := int(promise.IntValue) // convert to int64 + if index < 0 || index >= len(resolvedVal) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "index [%v] is out of range of %v", index, resolvedVal) + } + currVal = resolvedVal[attr.GetIntValue()] + default: + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "unexpected type [%T] for value %v", currVal, currVal) + } + } + + // Marshal the current value to MessagePack bytes + resolvedBinaryBytes, err := msgpack.Marshal(currVal) + if err != nil { + return nil, err + } + // Construct and return the binary-encoded literal + return constructResolvedBinary(resolvedBinaryBytes, serializationFormat), nil +} + +func constructResolvedBinary(resolvedBinaryBytes []byte, serializationFormat string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: resolvedBinaryBytes, + Tag: serializationFormat, + }, + }, + }, + }, + } +} + // convertInterfaceToLiteral converts the protobuf struct (e.g. dataclass) to literal func convertInterfaceToLiteral(nodeID string, obj interface{}) (*core.Literal, error) { @@ -141,7 +233,7 @@ func convertInterfaceToLiteral(nodeID string, obj interface{}) (*core.Literal, e return literal, nil } -// convertInterfaceToLiteralScalar converts the a single value to a literal scalar +// convertInterfaceToLiteralScalar converts a single value to a literal scalar func convertInterfaceToLiteralScalar(nodeID string, obj interface{}) (*core.Literal_Scalar, error) { value := &core.Primitive{} diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go index 8724f5287d..1467fc0ea4 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go @@ -2,8 +2,11 @@ package nodes import ( "context" + "fmt" + "reflect" "testing" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/structpb" @@ -11,6 +14,49 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" ) +// FlyteFile and FlyteDirectory represented as map[interface{}]interface{} +type FlyteFile map[interface{}]interface{} +type FlyteDirectory map[interface{}]interface{} + +// InnerDC struct (equivalent to InnerDC dataclass in Python) +type InnerDC struct { + A int `json:"a"` + B float64 `json:"b"` + C string `json:"c"` + D bool `json:"d"` + E []int `json:"e"` + F []FlyteFile `json:"f"` + G [][]int `json:"g"` + H []map[int]bool `json:"h"` + I map[int]bool `json:"i"` + J map[int]FlyteFile `json:"j"` + K map[int][]int `json:"k"` + L map[int]map[int]int `json:"l"` + M map[string]string `json:"m"` + N FlyteFile `json:"n"` + O FlyteDirectory `json:"o"` +} + +// DC struct (equivalent to DC dataclass in Python) +type DC struct { + A int `json:"a"` + B float64 `json:"b"` + C string `json:"c"` + D bool `json:"d"` + E []int `json:"e"` + F []FlyteFile `json:"f"` + G [][]int `json:"g"` + H []map[int]bool `json:"h"` + I map[int]bool `json:"i"` + J map[int]FlyteFile `json:"j"` + K map[int][]int `json:"k"` + L map[int]map[int]int `json:"l"` + M map[string]string `json:"m"` + N FlyteFile `json:"n"` + O FlyteDirectory `json:"o"` + Inner InnerDC `json:"inner_dc"` +} + func NewScalarLiteral(value string) *core.Literal { return &core.Literal{ Value: &core.Literal_Scalar{ @@ -32,7 +78,7 @@ func NewStructFromMap(m map[string]interface{}) *structpb.Struct { return st } -func TestResolveAttrPathIn(t *testing.T) { +func TestResolveAttrPathInStruct(t *testing.T) { args := []struct { literal *core.Literal @@ -52,7 +98,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -74,7 +120,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 1, }, @@ -95,7 +141,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -120,12 +166,12 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 1, }, @@ -150,7 +196,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -160,7 +206,7 @@ func TestResolveAttrPathIn(t *testing.T) { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ - &core.Literal{ + { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ @@ -182,11 +228,11 @@ func TestResolveAttrPathIn(t *testing.T) { Value: &core.Literal_Map{ Map: &core.LiteralMap{ Literals: map[string]*core.Literal{ - "foo": &core.Literal{ + "foo": { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ - &core.Literal{ + { Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ @@ -204,17 +250,17 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 0, }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "bar", }, @@ -223,6 +269,53 @@ func TestResolveAttrPathIn(t *testing.T) { expected: NewScalarLiteral("car"), hasError: false, }, + // - nested map {"foo": {"bar": {"baz": 42}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": map[string]interface{}{ + "baz": 42, + }, + }, + }, + ), + }, + }, + }, + }, + // Test accessing the entire nested map at foo.bar + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "bar", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "baz": 42, + }, + ), + }, + }, + }, + }, + hasError: false, + }, // - exception key error with map { literal: &core.Literal{ @@ -235,7 +328,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "random", }, @@ -257,7 +350,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 2, }, @@ -278,7 +371,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "random", }, @@ -303,12 +396,12 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 100, }, @@ -329,3 +422,1113 @@ func TestResolveAttrPathIn(t *testing.T) { } } } + +func createNestedDC() DC { + flyteFile := FlyteFile{ + "path": "s3://my-s3-bucket/example.txt", + } + + flyteDirectory := FlyteDirectory{ + "path": "s3://my-s3-bucket/s3_flyte_dir", + } + + // Example of initializing InnerDC + innerDC := InnerDC{ + A: -1, + B: -2.1, + C: "Hello, Flyte", + D: false, + E: []int{0, 1, 2, -1, -2}, + F: []FlyteFile{flyteFile}, + G: [][]int{{0}, {1}, {-1}}, + H: []map[int]bool{{0: false}, {1: true}, {-1: true}}, + I: map[int]bool{0: false, 1: true, -1: false}, + J: map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }, + K: map[int][]int{ + 0: {0, 1, -1}, + }, + L: map[int]map[int]int{ + 1: {-1: 0}, + }, + M: map[string]string{ + "key": "value", + }, + N: flyteFile, + O: flyteDirectory, + } + + // Initializing DC + dc := DC{ + A: 1, + B: 2.1, + C: "Hello, Flyte", + D: false, + E: []int{0, 1, 2, -1, -2}, + F: []FlyteFile{flyteFile}, + G: [][]int{{0}, {1}, {-1}}, + H: []map[int]bool{{0: false}, {1: true}, {-1: true}}, + I: map[int]bool{0: false, 1: true, -1: false}, + J: map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }, + K: map[int][]int{ + 0: {0, 1, -1}, + }, + L: map[int]map[int]int{ + 1: {-1: 0}, + }, + M: map[string]string{ + "key": "value", + }, + N: flyteFile, + O: flyteDirectory, + Inner: innerDC, + } + return dc +} + +func TestResolveAttrPathInBinary(t *testing.T) { + // Helper function to convert a map to msgpack bytes and then to BinaryIDL + toMsgpackBytes := func(m interface{}) []byte { + msgpackBytes, err := msgpack.Marshal(m) + assert.NoError(t, err) + return msgpackBytes + } + + flyteFile := FlyteFile{ + "path": "s3://my-s3-bucket/example.txt", + } + + flyteDirectory := FlyteDirectory{ + "path": "s3://my-s3-bucket/s3_flyte_dir", + } + + nestedDC := createNestedDC() + literalNestedDC := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(nestedDC), + Tag: "msgpack", + }, + }, + }, + }, + } + + args := []struct { + literal *core.Literal + path []*core.PromiseAttribute + expected *core.Literal + hasError bool + }{ + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "A", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "B", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(2.1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "C", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes("Hello, Flyte"), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "D", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(false), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "E", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "F", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]FlyteFile{{"path": "s3://my-s3-bucket/example.txt"}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "H", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "I", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[int]bool{0: false, 1: true, -1: false}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "J", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "K", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int][]int{ + 0: {0, 1, -1}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "L", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int]map[int]int{ + 1: {-1: 0}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "M", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[string]string{ + "key": "value", + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "N", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteFile), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "O", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteDirectory), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(nestedDC.Inner), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "A", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "B", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-2.1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "C", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes("Hello, Flyte"), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "D", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(false), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "E", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "F", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]FlyteFile{flyteFile}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]int{0}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 2, + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "H", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "I", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[int]bool{0: false, 1: true, -1: false}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "J", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "K", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int][]int{ + 0: {0, 1, -1}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "L", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int]map[int]int{ + 1: {-1: 0}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "M", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[string]string{ + "key": "value", + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "N", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteFile), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "O", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteDirectory), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - exception case with non-existing key in nested map + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing a non-existing key in the nested map + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "baz", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "unknown", + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + // - exception case with out-of-range index in list + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(42), 3.14, "str"}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing an out-of-range index in the list + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 10, + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + } + + ctx := context.Background() + for i, arg := range args { + resolved, err := resolveAttrPathInPromise(ctx, nil, "", arg.literal, arg.path) + if arg.hasError { + assert.Error(t, err, i) + assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i) + } else { + var expectedValue, actualValue interface{} + + // Helper function to unmarshal a Binary Literal into an interface{} + unmarshalBinaryLiteral := func(literal *core.Literal) (interface{}, error) { + if scalar, ok := literal.Value.(*core.Literal_Scalar); ok { + if binary, ok := scalar.Scalar.Value.(*core.Scalar_Binary); ok { + var value interface{} + err := msgpack.Unmarshal(binary.Binary.Value, &value) + return value, err + } + } + return nil, fmt.Errorf("literal is not a Binary Scalar") + } + + // Unmarshal the expected value + expectedValue, err := unmarshalBinaryLiteral(arg.expected) + if err != nil { + t.Fatalf("Failed to unmarshal expected value in test case %d: %v", i, err) + } + + // Unmarshal the resolved value + actualValue, err = unmarshalBinaryLiteral(resolved) + if err != nil { + t.Fatalf("Failed to unmarshal resolved value in test case %d: %v", i, err) + } + + // Deeply compare the expected and actual values, ignoring map ordering + if !reflect.DeepEqual(expectedValue, actualValue) { + t.Fatalf("Test case %d: Expected %+v, but got %+v", i, expectedValue, actualValue) + } + } + } +} diff --git a/go.mod b/go.mod index 6c25974da0..cf10a84c8e 100644 --- a/go.mod +++ b/go.mod @@ -166,6 +166,7 @@ require ( github.com/robfig/cron/v3 v3.0.0 // indirect github.com/sendgrid/rest v2.6.9+incompatible // indirect github.com/sendgrid/sendgrid-go v3.10.0+incompatible // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.4.1 // indirect