Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Jul 18, 2024
1 parent 0cddc15 commit d2e2979
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/workflow_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (w *WorkflowManager) CreateWorkflow(
if err != nil {
logger.Errorf(ctx, "Failed to compile workflow with err: %v", err)
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"failed to compile workflow for [%+v] with err %v", request.Id, err)
"failed to compile workflow for [%+v] with err: %v", request.Id, err)
}
err = validation.ValidateCompiledWorkflow(
*request.Id, workflowClosure, w.config.RegistrationValidationConfiguration())
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/workflowengine/impl/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *workflowCompiler) CompileWorkflow(

compiledWorkflowClosure, err := compiler.CompileWorkflow(primaryWf, subworkflows, tasks, launchPlans)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to compile workflow with err %v", err)
return nil, errors.NewFlyteAdminError(codes.InvalidArgument, err.Error())
}
return compiledWorkflowClosure, nil
}
Expand Down
4 changes: 2 additions & 2 deletions flytepropeller/pkg/compiler/errors/compiler_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ func NewDuplicateIDFoundErr(nodeID string) *CompileError {
)
}

func NewMismatchingTypesErr(nodeID, fromVar, fromType, toType string) *CompileError {
func NewMismatchingTypesErr(nodeID, fromVar, fromType, toVar, toType string) *CompileError {
return newError(
MismatchingTypes,
fmt.Sprintf("Variable [%v] (type [%v]) doesn't match expected type [%v].", fromVar, fromType,
fmt.Sprintf("The output variable '%v' has type [%v], but it's assigned to the input variable '%v' which has type type [%v].", fromVar, fromType, toVar,
toType),
nodeID,
)
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor

inputType := validators.LiteralTypeForLiteral(inputVal)
if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, "", inputVar, v.Type.String(), inputType.String()))
continue
}

Expand Down
22 changes: 14 additions & 8 deletions flytepropeller/pkg/compiler/validators/bindings.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package validators

import (
"fmt"
"reflect"

"k8s.io/apimachinery/pkg/util/sets"
Expand All @@ -11,9 +12,10 @@ import (
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/typing"
)

func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, binding *flyte.BindingData,
func validateBinding(w c.WorkflowBuilder, node c.Node, nodeParam string, binding *flyte.BindingData,
expectedType *flyte.LiteralType, errs errors.CompileErrors, validateParamTypes bool) (
resolvedType *flyte.LiteralType, upstreamNodes []c.NodeID, ok bool) {
nodeID := node.GetId()

// Non-scalar bindings will fail to introspect the type through a union type so we resolve them beforehand
switch binding.GetValue().(type) {
Expand All @@ -31,7 +33,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
var ok bool

for _, t := range expectedType.GetUnionType().GetVariants() {
resolvedType1, nodeIds1, ok1 := validateBinding(w, nodeID, nodeParam, binding, t, errors.NewCompileErrors(), validateParamTypes)
resolvedType1, nodeIds1, ok1 := validateBinding(w, node, nodeParam, binding, t, errors.NewCompileErrors(), validateParamTypes)
if ok1 {
if ok {
errs.Collect(errors.NewAmbiguousBindingUnionValue(nodeID, nodeParam, expectedType.String(), binding.String(), matchingType.String(), t.String()))
Expand Down Expand Up @@ -63,7 +65,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
allNodeIds := make([]c.NodeID, 0, len(val.Collection.GetBindings()))
var subType *flyte.LiteralType
for _, v := range val.Collection.GetBindings() {
if resolvedType, nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, expectedType.GetCollectionType(), errs.NewScope(), validateParamTypes); ok {
if resolvedType, nodeIds, ok := validateBinding(w, node, nodeParam, v, expectedType.GetCollectionType(), errs.NewScope(), validateParamTypes); ok {
allNodeIds = append(allNodeIds, nodeIds...)
subType = resolvedType
}
Expand All @@ -87,7 +89,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
allNodeIds := make([]c.NodeID, 0, len(val.Map.GetBindings()))
var subType *flyte.LiteralType
for _, v := range val.Map.GetBindings() {
if resolvedType, nodeIds, ok := validateBinding(w, nodeID, nodeParam, v, expectedType.GetMapValueType(), errs.NewScope(), validateParamTypes); ok {
if resolvedType, nodeIds, ok := validateBinding(w, node, nodeParam, v, expectedType.GetMapValueType(), errs.NewScope(), validateParamTypes); ok {
allNodeIds = append(allNodeIds, nodeIds...)
subType = resolvedType
}
Expand Down Expand Up @@ -119,7 +121,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
// If the variable has an index. We expect param to be a collection.
if v.Index != nil {
if cType := param.GetType().GetCollectionType(); cType == nil {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, val.Promise.Var, param.Type.String(), expectedType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, "", val.Promise.Var, param.Type.String(), expectedType.String()))
} else {
sourceType = cType
}
Expand Down Expand Up @@ -152,7 +154,11 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
return param.GetType(), []c.NodeID{val.Promise.NodeId}, true
}

errs.Collect(errors.NewMismatchingTypesErr(nodeID, val.Promise.Var, sourceType.String(), expectedType.String()))
outputNode, _ := w.GetNode(val.Promise.NodeId)
outputVar := fmt.Sprintf("%s.%s", outputNode.GetTask().GetID().Name, val.Promise.Var)
inputVar := fmt.Sprintf("%s.%s", node.GetSubWorkflow().GetCoreWorkflow().GetTemplate().GetId().Name, nodeParam)
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), outputVar, sourceType.String(), inputVar, expectedType.String()))
return nil, nil, !errs.HasErrors()
}
}

Expand All @@ -167,7 +173,7 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
if literalType == nil {
errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(val.Scalar.GetValue()).String()))
} else if validateParamTypes && !AreTypesCastable(literalType, expectedType) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, "", nodeParam, literalType.String(), expectedType.String()))
}

if expectedType.GetEnumType() != nil {
Expand Down Expand Up @@ -223,7 +229,7 @@ func ValidateBindings(w c.WorkflowBuilder, node c.Node, bindings []*flyte.Bindin
}

providedBindings.Insert(binding.GetVar())
if resolvedType, upstreamNodes, bindingOk := validateBinding(w, node.GetId(), binding.GetVar(), binding.GetBinding(),
if resolvedType, upstreamNodes, bindingOk := validateBinding(w, node, binding.GetVar(), binding.GetBinding(),
param.Type, errs.NewScope(), validateParamTypes); bindingOk {
for _, upNode := range upstreamNodes {
// Add implicit Edges
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/validators/condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl
expr.GetComparison().GetLeftValue(), requireParamType, errs.NewScope())
if op1Valid && op2Valid && op1Type != nil && op2Type != nil {
if op1Type.String() != op2Type.String() {
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue",
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "", "RightValue",
op1Type.String(), op2Type.String()))
}
}
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/validators/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func validateInputVar(n c.NodeBuilder, paramName string, requireParamType bool,
func validateVarType(nodeID c.NodeID, paramName string, param *flyte.Variable,
expectedType *flyte.LiteralType, errs errors.CompileErrors) (ok bool) {
if param.GetType().String() != expectedType.String() {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, param.GetType().String(), expectedType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, param.GetType().String(), "", expectedType.String()))
}

return !errs.HasErrors()
Expand Down

0 comments on commit d2e2979

Please sign in to comment.