diff --git a/flytecopilot/go.mod b/flytecopilot/go.mod index d943bb5153..a8071b5a8a 100644 --- a/flytecopilot/go.mod +++ b/flytecopilot/go.mod @@ -82,6 +82,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flytecopilot/go.sum b/flytecopilot/go.sum index b1f65b79e1..8f33fe7002 100644 --- a/flytecopilot/go.sum +++ b/flytecopilot/go.sum @@ -309,6 +309,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/flyteidl/clients/go/coreutils/extract_literal_test.go b/flyteidl/clients/go/coreutils/extract_literal_test.go index 66b20439c2..0cd4c2fb16 100644 --- a/flyteidl/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl/clients/go/coreutils/extract_literal_test.go @@ -4,6 +4,7 @@ package coreutils import ( + "os" "testing" "time" @@ -125,6 +126,7 @@ func TestFetchLiteral(t *testing.T) { }) t.Run("Generic", func(t *testing.T) { + os.Setenv(FlyteUseOldDcFormat, "true") literalVal := map[string]interface{}{ "x": 1, "y": "ystringvalue", @@ -150,6 +152,7 @@ func TestFetchLiteral(t *testing.T) { for key, val := range expectedStructVal.Fields { assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) } + os.Unsetenv(FlyteUseOldDcFormat) }) t.Run("Generic Passed As String", func(t *testing.T) { diff --git a/flyteidl/clients/go/coreutils/literals.go b/flyteidl/clients/go/coreutils/literals.go index 2bb789b423..6f292d7118 100644 --- a/flyteidl/clients/go/coreutils/literals.go +++ b/flyteidl/clients/go/coreutils/literals.go @@ -5,20 +5,24 @@ import ( "encoding/json" "fmt" "math" + "os" "reflect" "strconv" "strings" "time" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flytestdlib/storage" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/storage" ) const MESSAGEPACK = "msgpack" +const FlyteUseOldDcFormat = "FLYTE_USE_OLD_DC_FORMAT" func MakePrimitive(v interface{}) (*core.Primitive, error) { switch p := v.(type) { @@ -561,12 +565,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro strValue = fmt.Sprintf("%.0f", math.Trunc(f)) } if newT.Simple == core.SimpleType_STRUCT { + useOldFormat := strings.ToLower(os.Getenv(FlyteUseOldDcFormat)) if _, isValueStringType := v.(string); !isValueStringType { - byteValue, err := json.Marshal(v) - if err != nil { - return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) + if useOldFormat == "1" || useOldFormat == "t" || useOldFormat == "true" { + byteValue, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) + } + strValue = string(byteValue) + } else { + byteValue, err := msgpack.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unable to marshal to msgpack bytes for struct value %v", v) + } + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: byteValue, + Tag: MESSAGEPACK, + }, + }, + }, + }, + }, nil } - strValue = string(byteValue) } } lv, err := MakeLiteralForSimpleType(newT.Simple, strValue) diff --git a/flyteidl/clients/go/coreutils/literals_test.go b/flyteidl/clients/go/coreutils/literals_test.go index 3b5daf4b27..f2d8c9e5b2 100644 --- a/flyteidl/clients/go/coreutils/literals_test.go +++ b/flyteidl/clients/go/coreutils/literals_test.go @@ -5,6 +5,7 @@ package coreutils import ( "fmt" + "os" "reflect" "strconv" "testing" @@ -14,6 +15,7 @@ import ( "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -455,6 +457,7 @@ func TestMakeLiteralForType(t *testing.T) { }) t.Run("Generic", func(t *testing.T) { + os.Setenv(FlyteUseOldDcFormat, "true") literalVal := map[string]interface{}{ "x": 1, "y": "ystringvalue", @@ -480,6 +483,69 @@ func TestMakeLiteralForType(t *testing.T) { for key, val := range expectedStructVal.Fields { assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) } + os.Unsetenv(FlyteUseOldDcFormat) + }) + + t.Run("SimpleBinary", func(t *testing.T) { + // We compare the deserialized values instead of the raw msgpack bytes because Go does not guarantee the order + // of map keys during serialization. This means that while the serialized bytes may differ, the deserialized + // values should be logically equivalent. + + var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + v := map[string]interface{}{ + "a": int64(1), + "b": 3.14, + "c": "example_string", + "d": map[string]interface{}{ + "1": int64(100), + "2": int64(200), + }, + "e": map[string]interface{}{ + "a": int64(1), + "b": 3.14, + }, + "f": []string{"a", "b", "c"}, + } + + val, err := MakeLiteralForType(literalType, v) + assert.NoError(t, err) + + msgpackBytes, err := msgpack.Marshal(v) + assert.NoError(t, err) + + literalVal := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: msgpackBytes, + Tag: MESSAGEPACK, + }, + }, + }, + }, + } + + expectedLiteralVal, err := ExtractFromLiteral(literalVal) + assert.NoError(t, err) + actualLiteralVal, err := ExtractFromLiteral(val) + assert.NoError(t, err) + + // Check if the extracted value is of type *core.Binary (not []byte) + expectedBinary, ok := expectedLiteralVal.(*core.Binary) + assert.True(t, ok, "expectedLiteralVal is not of type *core.Binary") + actualBinary, ok := actualLiteralVal.(*core.Binary) + assert.True(t, ok, "actualLiteralVal is not of type *core.Binary") + + // Now check if the Binary values match + var expectedVal, actualVal map[string]interface{} + err = msgpack.Unmarshal(expectedBinary.Value, &expectedVal) + assert.NoError(t, err) + err = msgpack.Unmarshal(actualBinary.Value, &actualVal) + assert.NoError(t, err) + + // Finally, assert that the deserialized values are equal + assert.Equal(t, expectedVal, actualVal) }) t.Run("ArrayStrings", func(t *testing.T) { diff --git a/flyteidl/go.mod b/flyteidl/go.mod index 4c913dcb4d..0da94cea32 100644 --- a/flyteidl/go.mod +++ b/flyteidl/go.mod @@ -13,6 +13,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/pkg/errors v0.9.1 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 golang.org/x/net v0.27.0 diff --git a/flyteidl/go.sum b/flyteidl/go.sum index f440e247e9..b398d5d02f 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -215,6 +215,8 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=