From 1dcfbc488c6d7e30bbe06f972da4779993279542 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 7 Aug 2024 09:52:53 +0800 Subject: [PATCH] support attribute access Signed-off-by: Future-Outlier --- flytepropeller/go.mod | 2 + flytepropeller/go.sum | 4 + .../controller/nodes/attr_path_resolver.go | 97 +++++++++++++++++-- go.mod | 3 + go.sum | 4 + 5 files changed, 104 insertions(+), 6 deletions(-) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 5d828f9e9b..358e4c7966 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -25,6 +25,7 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 + github.com/vmihailenco/msgpack/v5 v5.4.1 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/trace v1.24.0 @@ -123,6 +124,7 @@ require ( github.com/spf13/viper v1.11.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.2.0 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect go.opentelemetry.io/otel/exporters/jaeger v1.17.0 // indirect diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 8bbdd06eba..cba516732a 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -407,6 +407,10 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index 42150cb887..79d6ef8b6f 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -1,10 +1,12 @@ package nodes import ( - "google.golang.org/protobuf/types/known/structpb" - + "encoding/json" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/types/known/structpb" + "strings" ) // resolveAttrPathInPromise resolves the literal with attribute path @@ -12,7 +14,6 @@ import ( func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { var currVal *core.Literal = literal var tmpVal *core.Literal - var err error var exist bool count := 0 @@ -38,13 +39,19 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath } // resolve dataclass - if currVal.GetScalar() != nil && currVal.GetScalar().GetGeneric() != nil { - st := currVal.GetScalar().GetGeneric() + if scalar := currVal.GetScalar(); scalar != nil { // start from index "count" - currVal, err = resolveAttrPathInPbStruct(nodeID, st, bindAttrPath[count:]) + var err error + + if json := scalar.GetJson(); json != nil { + currVal, err = resolveAttrPathInJson(nodeID, json.GetValue(), bindAttrPath[count:]) + } else if generic := scalar.GetGeneric(); generic != nil { + currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[count:]) + } if err != nil { return nil, err } + } return currVal, nil @@ -84,6 +91,82 @@ func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath return literal, err } +// resolveAttrPathInJson resolves the msgpack bytes (e.g. dataclass) with attribute path +func resolveAttrPathInJson(nodeID string, json_byte []byte, bindAttrPath []*core.PromiseAttribute) (*core.Literal, + error) { + + var currVal interface{} + var tmpVal interface{} + var exist bool + var jsonStr string + + err := msgpack.Unmarshal(json_byte, &jsonStr) + if err != nil { + return nil, err + } + + // Golang has problem with unmarshalling integer as float64 + // reference: https://stackoverflow.com/questions/22343083/json-unmarshaling-with-long-numbers-gives-floating-point-number + + decoder := json.NewDecoder(strings.NewReader(jsonStr)) + decoder.UseNumber() + err = decoder.Decode(&tmpVal) + if err != nil { + return nil, err + } + currVal = convertNumbers(tmpVal) + + // Turn the current value to a map so it can be resolved more easily + for _, attr := range bindAttrPath { + switch resolvedVal := currVal.(type) { + // map + case map[string]interface{}: + tmpVal, exist = resolvedVal[attr.GetStringValue()] + if !exist { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal) + } + currVal = tmpVal + // list + case []interface{}: + if int(attr.GetIntValue()) >= len(resolvedVal) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal) + } + currVal = resolvedVal[attr.GetIntValue()] + } + } + + // After resolve, convert the interface to literal + literal, err := convertInterfaceToLiteral(nodeID, currVal) + + return literal, err +} + +// convertNumbers recursively converts json.Number to int64 or float64 +func convertNumbers(v interface{}) interface{} { + switch vv := v.(type) { + case map[string]interface{}: + for key, value := range vv { + vv[key] = convertNumbers(value) + } + return vv + case []interface{}: + for i, value := range vv { + vv[i] = convertNumbers(value) + } + return vv + case json.Number: + // Try to convert to int64 first + if intVal, err := vv.Int64(); err == nil { + return intVal + } + // If it fails, fall back to float64 + if floatVal, err := vv.Float64(); err == nil { + return floatVal + } + } + return v +} + // convertInterfaceToLiteral converts the protobuf struct (e.g. dataclass) to literal func convertInterfaceToLiteral(nodeID string, obj interface{}) (*core.Literal, error) { @@ -137,6 +220,8 @@ func convertInterfaceToLiteralScalar(nodeID string, obj interface{}) (*core.Lite value.Value = &core.Primitive_StringValue{StringValue: obj} case int: value.Value = &core.Primitive_Integer{Integer: int64(obj)} + case int64: + value.Value = &core.Primitive_Integer{Integer: obj} case float64: value.Value = &core.Primitive_FloatValue{FloatValue: obj} case bool: diff --git a/go.mod b/go.mod index 3a7098d3c0..d0709feb20 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( golang.org/x/sync v0.7.0 gorm.io/driver/postgres v1.5.3 sigs.k8s.io/controller-runtime v0.16.3 + ) require ( @@ -177,6 +178,8 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect; indirects + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wI2L/jsondiff v0.5.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 // indirect diff --git a/go.sum b/go.sum index 05db1b9c1c..f2b201cdeb 100644 --- a/go.sum +++ b/go.sum @@ -1331,6 +1331,10 @@ github.com/unrolled/secure v0.0.0-20181005190816-ff9db2ff917f/go.mod h1:mnPT77IA github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/wI2L/jsondiff v0.5.0 h1:RRMTi/mH+R2aXcPe1VYyvGINJqQfC3R+KSEakuU1Ikw= github.com/wI2L/jsondiff v0.5.0/go.mod h1:qqG6hnK0Lsrz2BpIVCxWiK9ItsBCpIZQiv0izJjOZ9s= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I=