From 14e581dca563460cb03f188657b883232fc9fe28 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 19 Jul 2023 15:02:26 -0700 Subject: [PATCH 1/5] Check for None values in branch nodes Signed-off-by: Kevin Su --- .../pkg/controller/nodes/branch/comparator.go | 26 ++++++++-- .../pkg/controller/nodes/branch/evaluator.go | 20 +++++--- .../controller/nodes/branch/evaluator_test.go | 51 +++++++++++++++++++ 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator.go b/flytepropeller/pkg/controller/nodes/branch/comparator.go index 4fc4f2224f..57e1d2c02d 100644 --- a/flytepropeller/pkg/controller/nodes/branch/comparator.go +++ b/flytepropeller/pkg/controller/nodes/branch/comparator.go @@ -72,6 +72,16 @@ var perTypeComparators = map[string]comparators{ } func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { + if lValue == nil || rValue == nil { + switch op { + case core.ComparisonExpression_EQ: + return lValue == rValue, nil + case core.ComparisonExpression_NEQ: + return lValue != rValue, nil + default: + return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue) + } + } lValueType := reflect.TypeOf(lValue.Value) rValueType := reflect.TypeOf(rValue.Value) if lValueType != rValueType { @@ -116,24 +126,32 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive.") + if rValue.GetScalar().GetNoneType() == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + } } return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op) } func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") + if lValue.GetScalar().GetNoneType() == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") + } } return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op) } func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") + if lValue.GetScalar().GetNoneType() == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") + } } if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + if rValue.GetScalar().GetNoneType() == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + } } return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op) } diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator.go b/flytepropeller/pkg/controller/nodes/branch/evaluator.go index fe6d7edac5..3ead692406 100644 --- a/flytepropeller/pkg/controller/nodes/branch/evaluator.go +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator.go @@ -24,10 +24,14 @@ func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.Litera var rPrim *core.Primitive if expr.GetLeftValue().GetPrimitive() == nil { - if nodeInputs == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + if len(expr.GetLeftValue().GetVar()) == 0 { + lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}} + } else { + if nodeInputs == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] } - lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] if lValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) } @@ -36,10 +40,14 @@ func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.Litera } if expr.GetRightValue().GetPrimitive() == nil { - if nodeInputs == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + if len(expr.GetRightValue().GetVar()) == 0 { + rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}} + } else { + if nodeInputs == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] } - rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] if rValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar()) } diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go index 895b731945..852a47c711 100644 --- a/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go @@ -100,6 +100,57 @@ func TestEvaluateComparison(t *testing.T) { assert.NoError(t, err) assert.False(t, v) }) + t.Run("CompareNoneAndLiteral", func(t *testing.T) { + // Compare lVal -> None and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{}, + Operator: core.ComparisonExpression_EQ, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: coreutils.MustMakePrimitive(1), + }, + }, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("CompareLiteralAndNone", func(t *testing.T) { + // Compare lVal -> literal and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: coreutils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_NEQ, + RightValue: &core.Operand{}, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareNoneAndNone", func(t *testing.T) { + // Compare lVal -> None and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{}, + Operator: core.ComparisonExpression_EQ, + RightValue: &core.Operand{}, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareNoneAndNoneWithError", func(t *testing.T) { + // Compare lVal -> None and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{}, + Operator: core.ComparisonExpression_GTE, + RightValue: &core.Operand{}, + } + _, err := EvaluateComparison(exp, nil) + assert.Error(t, err) + }) t.Run("CompareLiteralAndPrimitive", func(t *testing.T) { // Compare lVal -> literal and rVal -> primitive From c007d39a9606efd2f5af7e6c4f7d53eaf1d5cd0b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jul 2023 13:51:58 -0700 Subject: [PATCH 2/5] update Signed-off-by: Kevin Su --- flytepropeller/go.mod | 2 + flytepropeller/go.sum | 6 +- .../pkg/compiler/validators/condition.go | 9 ++- .../pkg/controller/nodes/branch/comparator.go | 60 +++++++++---------- .../nodes/branch/comparator_test.go | 24 ++++---- .../pkg/controller/nodes/branch/evaluator.go | 34 +++++++---- .../controller/nodes/branch/evaluator_test.go | 45 ++++++++++++-- 7 files changed, 114 insertions(+), 66 deletions(-) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 805d6fd1b1..28466e39dc 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -145,3 +145,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index eb5a82fbc0..db9915f3bd 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -258,12 +258,10 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.10 h1:SHeiaWRt8EAVuFsat+BJswtc07HTZ4DqhfTEYSm621k= -github.com/flyteorg/flyteidl v1.5.10/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e h1:TGpPrgo3THN43kA00erR0FgWUfGIKkL8Nd6yT/DWOuw= +github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.1.8 h1:UVYdqDdcIqz2JIso+m3MsaPSsTZJZyZQ6Eg7nhX9r/Y= github.com/flyteorg/flyteplugins v1.1.8/go.mod h1:sRxeatEOHq1b9bTxTRNcwoIkVTAVN9dTz8toXkfcz2E= -github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= -github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/flytestdlib v1.0.19 h1:2xY9wBCFUY4UafBkxchPe0EUiRxpjnMNjvomG3W/TfA= github.com/flyteorg/flytestdlib v1.0.19/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/flytepropeller/pkg/compiler/validators/condition.go b/flytepropeller/pkg/compiler/validators/condition.go index a70c5dcb2b..e4f4d6753b 100644 --- a/flytepropeller/pkg/compiler/validators/condition.go +++ b/flytepropeller/pkg/compiler/validators/condition.go @@ -15,6 +15,10 @@ func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operan } else if operand.GetPrimitive() != nil { // no validation literalType = literalTypeForPrimitive(operand.GetPrimitive()) + } else if operand.GetScalar().GetPrimitive() != nil { + literalType = literalTypeForPrimitive(operand.GetPrimitive()) + } else if operand.GetScalar().GetNoneType() != nil { + literalType = &flyte.LiteralType{Type: &flyte.LiteralType_Simple{Simple: flyte.SimpleType_NONE}} } else if len(operand.GetVar()) > 0 { if node.GetInterface() != nil { if param, paramOk := validateInputVar(node, operand.GetVar(), requireParamType, errs.NewScope()); paramOk { @@ -41,7 +45,10 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl expr.GetComparison().GetRightValue(), requireParamType, errs.NewScope()) op2Type, op2Valid := validateOperand(node, "LeftValue", expr.GetComparison().GetLeftValue(), requireParamType, errs.NewScope()) - if op1Valid && op2Valid && op1Type != nil && op2Type != nil { + // Valid expression + // 1. Both operands are primitive types and have the same types. + // 2. One of the operands is the None type. + if op1Valid && op2Valid && op1Type != nil && op2Type != nil && op1Type.GetSimple() != flyte.SimpleType_NONE && op2Type.GetSimple() != flyte.SimpleType_NONE { if op1Type.String() != op2Type.String() { errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue", op1Type.String(), op2Type.String())) diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator.go b/flytepropeller/pkg/controller/nodes/branch/comparator.go index 57e1d2c02d..0d15968d01 100644 --- a/flytepropeller/pkg/controller/nodes/branch/comparator.go +++ b/flytepropeller/pkg/controller/nodes/branch/comparator.go @@ -71,19 +71,21 @@ var perTypeComparators = map[string]comparators{ }, } -func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { - if lValue == nil || rValue == nil { +func Evaluate(lValue *core.Scalar, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetNoneType() != nil || rValue.GetNoneType() != nil { + lIsNone := lValue.GetNoneType() != nil + rIsNone := rValue.GetNoneType() != nil switch op { case core.ComparisonExpression_EQ: - return lValue == rValue, nil + return lIsNone == rIsNone, nil case core.ComparisonExpression_NEQ: - return lValue != rValue, nil + return lIsNone != rIsNone, nil default: return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue) } } - lValueType := reflect.TypeOf(lValue.Value) - rValueType := reflect.TypeOf(rValue.Value) + lValueType := reflect.TypeOf(lValue.GetPrimitive().Value) + rValueType := reflect.TypeOf(rValue.GetPrimitive().Value) if lValueType != rValueType { return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType) } @@ -100,58 +102,50 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[GT] not defined for boolean operands.") } - return comps.gt(lValue, rValue), nil + return comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_GTE: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[GTE] not defined for boolean operands.") } - return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil + return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_LT: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[LT] not defined for boolean operands.") } - return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil + return !(comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive())), nil case core.ComparisonExpression_LTE: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[LTE] not defined for boolean operands.") } - return !comps.gt(lValue, rValue), nil + return !comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_EQ: - return comps.eq(lValue, rValue), nil + return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_NEQ: - return !comps.eq(lValue, rValue), nil + return !comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil } return false, errors.Errorf(ErrorCodeMalformedBranch, "Unsupported operator type in Propeller. System error.") } -func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { - if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - if rValue.GetScalar().GetNoneType() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") - } +func Evaluate1(lValue *core.Scalar, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") } - return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op) + return Evaluate(lValue, rValue.GetScalar(), op) } -func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { - if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - if lValue.GetScalar().GetNoneType() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") - } +func Evaluate2(lValue *core.Literal, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue) } - return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op) + return Evaluate(lValue.GetScalar(), rValue, op) } func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { - if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - if lValue.GetScalar().GetNoneType() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") - } + if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue) } - if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - if rValue.GetScalar().GetNoneType() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") - } + if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") } - return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op) + return Evaluate(lValue.GetScalar(), rValue.GetScalar(), op) } diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator_test.go b/flytepropeller/pkg/controller/nodes/branch/comparator_test.go index d1c120ef4c..21f60272ac 100644 --- a/flytepropeller/pkg/controller/nodes/branch/comparator_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/comparator_test.go @@ -11,8 +11,8 @@ import ( ) func TestEvaluate_int(t *testing.T) { - p1 := coreutils.MustMakePrimitive(1) - p2 := coreutils.MustMakePrimitive(2) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -82,8 +82,8 @@ func TestEvaluate_int(t *testing.T) { } func TestEvaluate_float(t *testing.T) { - p1 := coreutils.MustMakePrimitive(1.0) - p2 := coreutils.MustMakePrimitive(2.0) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -153,8 +153,8 @@ func TestEvaluate_float(t *testing.T) { } func TestEvaluate_string(t *testing.T) { - p1 := coreutils.MustMakePrimitive("a") - p2 := coreutils.MustMakePrimitive("b") + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("a")}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("b")}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -224,8 +224,8 @@ func TestEvaluate_string(t *testing.T) { } func TestEvaluate_datetime(t *testing.T) { - p1 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC)) - p2 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC)) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -295,8 +295,8 @@ func TestEvaluate_datetime(t *testing.T) { } func TestEvaluate_duration(t *testing.T) { - p1 := coreutils.MustMakePrimitive(10 * time.Second) - p2 := coreutils.MustMakePrimitive(11 * time.Second) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(10 * time.Second)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(11 * time.Second)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -366,8 +366,8 @@ func TestEvaluate_duration(t *testing.T) { } func TestEvaluate_boolean(t *testing.T) { - p1 := coreutils.MustMakePrimitive(true) - p2 := coreutils.MustMakePrimitive(false) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(true)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(false)}} f := func(op core.ComparisonExpression_Operator) { // GT/LT = false msg := fmt.Sprintf("Evaluating: [%s]", op.String()) diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator.go b/flytepropeller/pkg/controller/nodes/branch/evaluator.go index 3ead692406..e81a127fb9 100644 --- a/flytepropeller/pkg/controller/nodes/branch/evaluator.go +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator.go @@ -20,39 +20,53 @@ const ErrorCodeFailedFetchOutputs = "FailedFetchOutputs" func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.LiteralMap) (bool, error) { var lValue *core.Literal var rValue *core.Literal - var lPrim *core.Primitive - var rPrim *core.Primitive + var lPrim *core.Scalar + var rPrim *core.Scalar if expr.GetLeftValue().GetPrimitive() == nil { - if len(expr.GetLeftValue().GetVar()) == 0 { - lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}} + if expr.GetLeftValue().GetScalar().GetNoneType() != nil { + lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetLeftValue().GetScalar()}} + } else if expr.GetLeftValue().GetScalar().GetUnion() != nil { + lValue = expr.GetLeftValue().GetScalar().GetUnion().GetValue() } else { if nodeInputs == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) } - lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] + input := nodeInputs.Literals[expr.GetLeftValue().GetVar()] + if input.GetScalar().GetUnion().GetValue() != nil { + lValue = input.GetScalar().GetUnion().GetValue() + } else { + lValue = input + } } if lValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) } } else { - lPrim = expr.GetLeftValue().GetPrimitive() + lPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetLeftValue().GetPrimitive()}} } if expr.GetRightValue().GetPrimitive() == nil { - if len(expr.GetRightValue().GetVar()) == 0 { - rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}} + if expr.GetRightValue().GetScalar().GetNoneType() != nil { + rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetRightValue().GetScalar()}} + } else if expr.GetRightValue().GetScalar().GetUnion() != nil { + rValue = expr.GetRightValue().GetScalar().GetUnion().GetValue() } else { if nodeInputs == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) } - rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] + input := nodeInputs.Literals[expr.GetRightValue().GetVar()] + if input.GetScalar().GetUnion().GetValue() != nil { + rValue = input.GetScalar().GetUnion().GetValue() + } else { + rValue = input + } } if rValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar()) } } else { - rPrim = expr.GetRightValue().GetPrimitive() + rPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetRightValue().GetPrimitive()}} } if lValue != nil && rValue != nil { diff --git a/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go index 852a47c711..a64031ae24 100644 --- a/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/evaluator_test.go @@ -56,6 +56,16 @@ func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExp } } +func getNoneOperand() *core.Operand { + return &core.Operand{ + Val: &core.Operand_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_NoneType{NoneType: &core.Void{}}, + }, + }, + } +} + func TestEvaluateComparison(t *testing.T) { t.Run("ComparePrimitives", func(t *testing.T) { // Compare primitives @@ -103,7 +113,7 @@ func TestEvaluateComparison(t *testing.T) { t.Run("CompareNoneAndLiteral", func(t *testing.T) { // Compare lVal -> None and rVal -> literal exp := &core.ComparisonExpression{ - LeftValue: &core.Operand{}, + LeftValue: getNoneOperand(), Operator: core.ComparisonExpression_EQ, RightValue: &core.Operand{ Val: &core.Operand_Primitive{ @@ -124,7 +134,30 @@ func TestEvaluateComparison(t *testing.T) { }, }, Operator: core.ComparisonExpression_NEQ, - RightValue: &core.Operand{}, + RightValue: getNoneOperand(), + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareUnionLiteralAndNone", func(t *testing.T) { + // Compare lVal -> literal and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}}, + }, + }, + }, + }, + }, + }, + Operator: core.ComparisonExpression_NEQ, + RightValue: getNoneOperand(), } v, err := EvaluateComparison(exp, nil) assert.NoError(t, err) @@ -133,9 +166,9 @@ func TestEvaluateComparison(t *testing.T) { t.Run("CompareNoneAndNone", func(t *testing.T) { // Compare lVal -> None and rVal -> None exp := &core.ComparisonExpression{ - LeftValue: &core.Operand{}, + LeftValue: getNoneOperand(), Operator: core.ComparisonExpression_EQ, - RightValue: &core.Operand{}, + RightValue: getNoneOperand(), } v, err := EvaluateComparison(exp, nil) assert.NoError(t, err) @@ -144,9 +177,9 @@ func TestEvaluateComparison(t *testing.T) { t.Run("CompareNoneAndNoneWithError", func(t *testing.T) { // Compare lVal -> None and rVal -> None exp := &core.ComparisonExpression{ - LeftValue: &core.Operand{}, + LeftValue: getNoneOperand(), Operator: core.ComparisonExpression_GTE, - RightValue: &core.Operand{}, + RightValue: getNoneOperand(), } _, err := EvaluateComparison(exp, nil) assert.Error(t, err) From 926b0b44550683f0c612ee88a716e62e57b8c060 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jul 2023 13:57:53 -0700 Subject: [PATCH 3/5] nit Signed-off-by: Kevin Su --- flytepropeller/pkg/controller/nodes/branch/comparator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/branch/comparator.go b/flytepropeller/pkg/controller/nodes/branch/comparator.go index 0d15968d01..5a05b2001f 100644 --- a/flytepropeller/pkg/controller/nodes/branch/comparator.go +++ b/flytepropeller/pkg/controller/nodes/branch/comparator.go @@ -128,7 +128,7 @@ func Evaluate(lValue *core.Scalar, rValue *core.Scalar, op core.ComparisonExpres func Evaluate1(lValue *core.Scalar, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue) } return Evaluate(lValue, rValue.GetScalar(), op) } @@ -145,7 +145,7 @@ func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.Compar return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue) } if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue) } return Evaluate(lValue.GetScalar(), rValue.GetScalar(), op) } From a3a6a3d65884a0fa408d579eabc5216c594b6521 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 16 Aug 2023 10:55:03 -0700 Subject: [PATCH 4/5] updated idl Signed-off-by: Kevin Su --- flytepropeller/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 28466e39dc..f8c73a472e 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -6,7 +6,7 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 - github.com/flyteorg/flyteidl v1.5.10 + github.com/flyteorg/flyteidl v1.5.16 github.com/flyteorg/flyteplugins v1.1.8 github.com/flyteorg/flytestdlib v1.0.19 github.com/ghodss/yaml v1.0.0 From 3b20c5af4c54003948c72a1037a822fc1c48b937 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 16 Aug 2023 10:57:34 -0700 Subject: [PATCH 5/5] updated idl Signed-off-by: Kevin Su --- flytepropeller/go.mod | 2 -- flytepropeller/go.sum | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 5b0e821166..602c8cb129 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -146,5 +146,3 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d - -replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index cb444252d7..d39bddecf0 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -242,8 +242,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e h1:TGpPrgo3THN43kA00erR0FgWUfGIKkL8Nd6yT/DWOuw= -github.com/flyteorg/flyteidl v1.5.14-0.20230721173646-90caa8bd2d9e/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.16 h1:S70wD7K99nKHZxmo8U16Jjhy1kZwoBh5ZQhZf3/6MPU= +github.com/flyteorg/flyteidl v1.5.16/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.1.16 h1:DIQxPERFMvTGnLTkkeG9R8STF3YMvxK1nPtFf+a6o5Q= github.com/flyteorg/flyteplugins v1.1.16/go.mod h1:HEd4yf0H8XfxMcHFwrTdTIJ/9lEAz83OpgcFQe47L6I= github.com/flyteorg/flytestdlib v1.0.22 h1:8RAc+TmME54FInf4+t6+C7X8Z/dW6i6aTs6W8SEzpI8=