Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flytepropeller][flyteadmin] Compiler unknown literal type error handling #5651

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flyteadmin/pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,8 @@
}
return statusErr
}

func NewInvalidLiteralTypeError(name string, err error) FlyteAdminError {
return NewFlyteAdminErrorf(codes.InvalidArgument,
fmt.Sprintf("Failed to validate literal type for [%s] with err: %s", name, err))

Check warning on line 219 in flyteadmin/pkg/errors/errors.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/errors/errors.go#L217-L219

Added lines #L217 - L219 were not covered by tests
}
4 changes: 4 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/execution_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func CheckAndFetchInputsForExecution(
default:
inputType = validators.LiteralTypeForLiteral(executionInputMap[name])
}
err := validators.ValidateLiteralType(inputType)
if err != nil {
return nil, errors.NewInvalidLiteralTypeError(name, err)
}
if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType)
}
Expand Down
38 changes: 38 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/execution_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (

var execConfig = testutils.GetApplicationConfigWithDefaultDomains()

const failedToValidateLiteralType = "Failed to validate literal type"

func TestValidateExecEmptyProject(t *testing.T) {
request := testutils.GetExecutionRequest()
request.Project = ""
Expand Down Expand Up @@ -209,6 +211,42 @@ func TestValidateExecEmptyInputs(t *testing.T) {
assert.EqualValues(t, expectedMap, actualInputs)
}

func TestValidateExecUnknownIDLInputs(t *testing.T) {
unsupportedLiteral := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{},
},
}
defaultInputs := &core.ParameterMap{
Parameters: map[string]*core.Parameter{
"foo": {
Var: &core.Variable{
// 1000 means an unsupported type
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: 1000}},
},
Behavior: &core.Parameter_Default{
Default: unsupportedLiteral,
},
},
},
}
userInputs := &core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": unsupportedLiteral, // This will lead to a nil inputType
},
}

_, err := CheckAndFetchInputsForExecution(
userInputs,
nil,
defaultInputs,
)
assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
}

func TestValidExecutionId(t *testing.T) {
err := CheckValidExecutionID("abcde123", "a")
assert.Nil(t, err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ func checkAndFetchExpectedInputForLaunchPlan(
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unexpected fixed_input %s", name)
}
inputType := validators.LiteralTypeForLiteral(fixedInput)
err := validators.ValidateLiteralType(inputType)
if err != nil {
return nil, errors.NewInvalidLiteralTypeError(name, err)
}
if !validators.AreTypesCastable(inputType, value.GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid fixed_input wrong type %s, expected %v, got %v instead", name, value.GetType(), inputType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,50 @@ func TestGetLpExpectedInvalidFixedInput(t *testing.T) {
assert.Nil(t, actualMap)
}

func TestGetLpExpectedInvalidFixedInputWithUnknownIDL(t *testing.T) {
unsupportedLiteral := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{},
},
}
workflowVariableMap := &core.VariableMap{
Variables: map[string]*core.Variable{
"foo": {
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: 1000}},
},
},
}
defaultInputs := &core.ParameterMap{
Parameters: map[string]*core.Parameter{
"foo": {
Var: &core.Variable{
// 1000 means an unsupported type
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: 1000}},
},
Behavior: &core.Parameter_Default{
Default: unsupportedLiteral,
},
},
},
}
fixedInputs := &core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": unsupportedLiteral, // This will lead to a nil inputType
},
}

_, err := checkAndFetchExpectedInputForLaunchPlan(
workflowVariableMap,
fixedInputs,
defaultInputs,
)

assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
}

func TestGetLpExpectedNoFixedInput(t *testing.T) {
request := testutils.GetLaunchPlanRequest()
actualMap, err := checkAndFetchExpectedInputForLaunchPlan(
Expand Down
4 changes: 4 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/signal_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ func ValidateSignalSetRequest(ctx context.Context, db repositoryInterfaces.Repos
if err != nil {
return err
}
err = propellervalidators.ValidateLiteralType(valueType)
if err != nil {
return errors.NewInvalidLiteralTypeError("", err)
}
if !propellervalidators.AreTypesCastable(lookupSignal.Type, valueType) {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"requested signal value [%v] is not castable to existing signal type [%v]",
Expand Down
48 changes: 48 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/signal_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,4 +283,52 @@ func TestValidateSignalUpdateRequest(t *testing.T) {
utils.AssertEqualWithSanitizedRegex(t,
"requested signal value [scalar:{ primitive:{ boolean:false } } ] is not castable to existing signal type [[8 1]]", ValidateSignalSetRequest(ctx, repo, request).Error())
})

t.Run("UnknownIDLType", func(t *testing.T) {
ctx := context.TODO()

// Define an unsupported literal type with a simple type of 1000
unsupportedLiteralType := &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: 1000, // Using 1000 as an unsupported type
},
}
unsupportedLiteralTypeBytes, _ := proto.Marshal(unsupportedLiteralType)

// Mock the repository to return a signal with this unsupported type
repo := repositoryMocks.NewMockRepository()
repo.SignalRepo().(*repositoryMocks.SignalRepoInterface).
OnGetMatch(mock.Anything, mock.Anything).Return(
models.Signal{
Type: unsupportedLiteralTypeBytes, // Set the unsupported type
},
nil,
)

// Set up the unsupported literal that will trigger the nil valueType condition
unsupportedLiteral := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{},
},
}

request := admin.SignalSetRequest{
Id: &core.SignalIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
SignalId: "signal",
},
Value: unsupportedLiteral, // This will lead to valueType being nil
}

// Invoke the function and check for the expected error
err := ValidateSignalSetRequest(ctx, repo, &request)
assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
})
}
17 changes: 3 additions & 14 deletions flyteadmin/pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package validation

import (
"fmt"
"net/url"
"strconv"
"strings"
Expand Down Expand Up @@ -283,19 +282,9 @@ func validateParameterMap(inputMap *core.ParameterMap, fieldName string) error {
defaultValue := defaultInput.GetDefault()
if defaultValue != nil {
inputType := validators.LiteralTypeForLiteral(defaultValue)

if inputType == nil {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
fmt.Sprintf(
"Flyte encountered an issue while determining\n"+
"the type of the default value for Parameter '%s' in '%s'.\n"+
"Registered type: [%s].\n"+
"Flyte needs to support the latest FlyteIDL to support this type.\n"+
"Suggested solution: Please update all of your Flyte images to the latest version and "+
"try again.",
name, fieldName, defaultInput.GetVar().GetType().String(),
),
)
err := validators.ValidateLiteralType(inputType)
if err != nil {
return errors.NewInvalidLiteralTypeError(name, err)
}

if !validators.AreTypesCastable(inputType, defaultInput.GetVar().GetType()) {
Expand Down
11 changes: 1 addition & 10 deletions flyteadmin/pkg/manager/impl/validation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,16 +347,7 @@ func TestValidateParameterMap(t *testing.T) {
err := validateParameterMap(&exampleMap, fieldName)
assert.Error(t, err)
fmt.Println(err.Error())
expectedErrMsg := fmt.Sprintf(
"Flyte encountered an issue while determining\n"+
"the type of the default value for Parameter '%s' in '%s'.\n"+
"Registered type: [%s].\n"+
"Flyte needs to support the latest FlyteIDL to support this type.\n"+
"Suggested solution: Please update all of your Flyte images to the latest version and "+
"try again.",
name, fieldName, exampleMap.Parameters[name].GetVar().GetType().String(),
)
assert.Equal(t, expectedErrMsg, err.Error())
assert.Contains(t, err.Error(), failedToValidateLiteralType)
})
}

Expand Down
3 changes: 2 additions & 1 deletion flytepropeller/pkg/compiler/errors/compiler_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestErrorCodes(t *testing.T) {
UnrecognizedValue: NewUnrecognizedValueErr("", ""),
WorkflowBuildError: NewWorkflowBuildError(errors.New("")),
NoNodesFound: NewNoNodesFoundErr(""),
InvalidLiteralTypeError: NewInvalidLiteralTypeErr("", "", errors.New("")),
}

for key, value := range testCases {
Expand All @@ -48,6 +49,6 @@ func TestIncludeSource(t *testing.T) {

SetConfig(Config{IncludeSource: true})
e = NewCycleDetectedInWorkflowErr("", "")
assert.Equal(t, e.source, "compiler_error_test.go:50")
assert.Equal(t, e.source, "compiler_error_test.go:51")
SetConfig(Config{})
}
11 changes: 11 additions & 0 deletions flytepropeller/pkg/compiler/errors/compiler_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ const (

// Field not found in the dataclass
FieldNotFoundError ErrorCode = "FieldNotFound"

// IDL not found when variable binding
InvalidLiteralTypeError ErrorCode = "InvalidLiteralType"
)

func NewBranchNodeNotSpecified(branchNodeID string) *CompileError {
Expand Down Expand Up @@ -218,6 +221,14 @@ func NewMismatchingVariablesErr(nodeID, fromVar, fromType, toVar, toType string)
)
}

func NewInvalidLiteralTypeErr(nodeID, inputVar string, err error) *CompileError {
return newError(
InvalidLiteralTypeError,
fmt.Sprintf("Failed to validate literal type for [%s] with err: %s", inputVar, err),
nodeID,
)
}

func NewMismatchingBindingsErr(nodeID, sinkParam, expectedType, receivedType string) *CompileError {
return newError(
MismatchingBindings,
Expand Down
7 changes: 7 additions & 0 deletions flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor
default:
inputType = validators.LiteralTypeForLiteral(inputVal)
}

err := validators.ValidateLiteralType(inputType)
if err != nil {
errs.Collect(errors.NewInvalidLiteralTypeErr(nodeID, inputVar, err))
continue
}

if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
continue
Expand Down
54 changes: 54 additions & 0 deletions flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,55 @@
package k8s

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common"
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/errors"
)

func TestValidateInputs_InvalidLiteralType(t *testing.T) {
nodeID := common.NodeID("test-node")

iface := &core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"input1": {
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: 1000,
},
},
},
},
},
}

inputs := core.LiteralMap{
Literals: map[string]*core.Literal{
"input1": nil, // Set this to nil to trigger the nil case
},
}

errs := errors.NewCompileErrors()
ok := validateInputs(nodeID, iface, inputs, errs)

assert.False(t, ok)
assert.True(t, errs.HasErrors())

idlNotFound := false
var errMsg string
for _, err := range errs.Errors().List() {
if err.Code() == "InvalidLiteralType" {
idlNotFound = true
errMsg = err.Error()
break
}
}
assert.True(t, idlNotFound, "Expected InvalidLiteralType error was not found in errors")

expectedContainedErrorMsg := "Failed to validate literal type"
assert.Contains(t, errMsg, expectedContainedErrorMsg)
}
17 changes: 17 additions & 0 deletions flytepropeller/pkg/compiler/validators/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,23 @@
return buildMultipleTypeUnion(innerType)
}

// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid.
func ValidateLiteralType(lt *core.LiteralType) error {
if lt == nil {
err := fmt.Errorf("got unknown literal type: [%v].\n"+
"Suggested solution: Please update all your Flyte deployment images to the latest version and try again", lt)
return err

Check warning on line 257 in flytepropeller/pkg/compiler/validators/utils.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/utils.go#L253-L257

Added lines #L253 - L257 were not covered by tests
}
if lt.GetCollectionType() != nil {
return ValidateLiteralType(lt.GetCollectionType())

Check warning on line 260 in flytepropeller/pkg/compiler/validators/utils.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/utils.go#L259-L260

Added lines #L259 - L260 were not covered by tests
}
if lt.GetMapValueType() != nil {
return ValidateLiteralType(lt.GetMapValueType())

Check warning on line 263 in flytepropeller/pkg/compiler/validators/utils.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/utils.go#L262-L263

Added lines #L262 - L263 were not covered by tests
}

return nil

Check warning on line 266 in flytepropeller/pkg/compiler/validators/utils.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/utils.go#L266

Added line #L266 was not covered by tests
}

// LiteralTypeForLiteral gets LiteralType for literal, nil if the value of literal is unknown, or type collection/map of
// type None if the literal is a non-homogeneous type.
func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType {
Expand Down
9 changes: 8 additions & 1 deletion flytepropeller/pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,15 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
}

size := -1
for _, variable := range literalMap.Literals {
for key, variable := range literalMap.Literals {
literalType := validators.LiteralTypeForLiteral(variable)
err := validators.ValidateLiteralType(literalType)
if err != nil {
errMsg := fmt.Sprintf("Failed to validate literal type for [%s] with err: %s", key, err)
return handler.DoTransition(handler.TransitionTypeEphemeral,
handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.IDLNotFoundErr, errMsg, nil),
), nil
}
switch literalType.Type.(type) {
case *idlcore.LiteralType_CollectionType:
collectionLength := len(variable.GetCollection().Literals)
Expand Down
Loading
Loading