diff --git a/go.mod b/go.mod index 54d4b4b4..27791495 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( go.arcalot.io/assert v1.8.0 - go.arcalot.io/dgraph v1.5.0 + go.arcalot.io/dgraph v1.6.0 go.arcalot.io/lang v1.1.0 go.arcalot.io/log/v2 v2.2.0 go.flow.arcalot.io/deployer v0.6.1 diff --git a/go.sum b/go.sum index 337fc78b..7d5f8c3b 100644 --- a/go.sum +++ b/go.sum @@ -123,8 +123,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.arcalot.io/assert v1.8.0 h1:hGcHMPncQXwQvjj7MbyOu2gg8VIBB00crUJZpeQOjxs= go.arcalot.io/assert v1.8.0/go.mod h1:nNmWPoNUHFyrPkNrD2aASm5yPuAfiWdB/4X7Lw3ykHk= -go.arcalot.io/dgraph v1.5.0 h1:6cGlxLzmmehJoD/nj0Hkql7uh90EU0A0GtZhGYkr28M= -go.arcalot.io/dgraph v1.5.0/go.mod h1:+Kxc81utiihMSmC1/ttSPGLDlWPpvgOpNxSFmIDPxFM= +go.arcalot.io/dgraph v1.6.0 h1:mJFZ1vdPEg3KtqyhNqYtWVAkxxWBWoJVUFZQ2Z4mbvE= +go.arcalot.io/dgraph v1.6.0/go.mod h1:+Kxc81utiihMSmC1/ttSPGLDlWPpvgOpNxSFmIDPxFM= go.arcalot.io/exex v0.2.0 h1:u44pjwPwcH57TF8knhaqVZP/1V/KbnRe//pKzMwDpLw= go.arcalot.io/exex v0.2.0/go.mod h1:5zlFr+7vOQNZKYCNOEDdsad+z/dlvXKs2v4kG+v+bQo= go.arcalot.io/lang v1.1.0 h1:ugglRKpd3qIMkdghAjKJxsziIgHm8QpxrzZPSXoa08I= diff --git a/internal/infer/infer.go b/internal/infer/infer.go index ac52a384..6c4b54b6 100644 --- a/internal/infer/infer.go +++ b/internal/infer/infer.go @@ -65,17 +65,25 @@ func Scope( // Type attempts to infer the data model from the data, possibly evaluating expressions. func Type( data any, - internalDataModel *schema.ScopeSchema, + internalDataModel schema.Scope, functions map[string]schema.Function, workflowContext map[string][]byte, ) (schema.Type, error) { - if expression, ok := data.(expressions.Expression); ok { - expressionType, err := expression.Type(internalDataModel, functions, workflowContext) + switch expr := data.(type) { + case expressions.Expression: + expressionType, err := expr.Type(internalDataModel, functions, workflowContext) if err != nil { - return nil, fmt.Errorf("failed to evaluate type of expression %s (%w)", expression.String(), err) + return nil, fmt.Errorf("failed to evaluate type of expression %s (%w)", expr.String(), err) } return expressionType, nil + case *OneOfExpression: + oneOfType, err := expr.Type(internalDataModel, functions, workflowContext) + if err != nil { + return nil, fmt.Errorf("failed to evaluate type of expression %s (%w)", expr.String(), err) + } + return oneOfType, nil } + v := reflect.ValueOf(data) switch v.Kind() { case reflect.Map: @@ -132,7 +140,7 @@ func Type( // mapType infers the type of a map value. 
func mapType( v reflect.Value, - internalDataModel *schema.ScopeSchema, + internalDataModel schema.Scope, functions map[string]schema.Function, workflowContext map[string][]byte, ) (schema.Type, error) { @@ -141,9 +149,7 @@ func mapType( return nil, fmt.Errorf("failed to infer map key type (%w)", err) } switch keyType.TypeID() { - case schema.TypeIDString: - fallthrough - case schema.TypeIDStringEnum: + case schema.TypeIDString, schema.TypeIDStringEnum: return objectType(v, internalDataModel, functions, workflowContext) case schema.TypeIDInt: case schema.TypeIDIntEnum: @@ -186,7 +192,7 @@ func mapType( func objectType( value reflect.Value, - internalDataModel *schema.ScopeSchema, + internalDataModel schema.Scope, functions map[string]schema.Function, workflowContext map[string][]byte, ) (schema.Type, error) { @@ -207,7 +213,7 @@ func objectType( nil, ) } - return schema.NewObjectSchema( + return schema.NewUnenforcedIDObjectSchema( generateRandomObjectID("inferred_schema"), properties, ), nil @@ -216,7 +222,7 @@ func objectType( // sliceType tries to infer the type of a slice. func sliceType( v reflect.Value, - internalDataModel *schema.ScopeSchema, + internalDataModel schema.Scope, functions map[string]schema.Function, workflowContext map[string][]byte, ) (schema.Type, error) { @@ -237,7 +243,7 @@ func sliceType( func sliceItemType( values []reflect.Value, - internalDataModel *schema.ScopeSchema, + internalDataModel schema.Scope, functions map[string]schema.Function, workflowContext map[string][]byte, ) (schema.Type, error) { diff --git a/internal/infer/infer_test.go b/internal/infer/infer_test.go index 80b35561..db476da7 100644 --- a/internal/infer/infer_test.go +++ b/internal/infer/infer_test.go @@ -2,6 +2,9 @@ package infer_test import ( "fmt" + "go.arcalot.io/assert" + "go.arcalot.io/lang" + "go.flow.arcalot.io/expressions" "testing" "go.flow.arcalot.io/engine/internal/infer" @@ -11,16 +14,31 @@ import ( type testEntry struct { name string input any + dataModel schema.Scope expectedOutputType schema.TypeID validate func(t schema.Type) error } +var testOneOf = infer.OneOfExpression{ + Discriminator: "option", + Options: map[string]any{ + "a": map[string]any{ + "value-1": 1, + }, + "b": map[string]any{ + "value-2": lang.Must2(expressions.New("$.a")), + }, + }, + NodePath: "n/a", +} + var testData = []testEntry{ { "string", "foo", + nil, schema.TypeIDString, - func(t schema.Type) error { + func(_ schema.Type) error { return nil }, }, @@ -30,6 +48,7 @@ var testData = []testEntry{ "foo": "bar", "baz": 42, }, + nil, schema.TypeIDObject, func(t schema.Type) error { objectSchema := t.(*schema.ObjectSchema) @@ -46,6 +65,7 @@ var testData = []testEntry{ { "slice", []string{"foo"}, + nil, schema.TypeIDList, func(t schema.Type) error { listType := t.(*schema.ListSchema) @@ -55,19 +75,95 @@ var testData = []testEntry{ return nil }, }, + { + "expression-1", + lang.Must2(expressions.New("$.a")), + schema.NewScopeSchema( + schema.NewObjectSchema("root", map[string]*schema.PropertySchema{ + "a": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }), + ), + schema.TypeIDString, + func(_ schema.Type) error { + return nil + }, + }, + { + "oneof-expression", + &testOneOf, + schema.NewScopeSchema( + schema.NewObjectSchema("root", map[string]*schema.PropertySchema{ + "a": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }), + ), + schema.TypeIDOneOfString, + 
func(t schema.Type) error { + return t.ValidateCompatibility( + schema.NewOneOfStringSchema[any]( + map[string]schema.Object{ + "a": schema.NewObjectSchema("n/a", map[string]*schema.PropertySchema{ + "value-1": schema.NewPropertySchema( + schema.NewIntSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }), + "b": schema.NewObjectSchema("n/a", map[string]*schema.PropertySchema{ + "value-2": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }), + }, + "option", + false, + ), + ) + }, + }, } func TestInfer(t *testing.T) { for _, entry := range testData { entry := entry t.Run(entry.name, func(t *testing.T) { - inferredType, err := infer.Type(entry.input, nil, nil, nil) + inferredType, err := infer.Type(entry.input, entry.dataModel, nil, nil) if err != nil { t.Fatalf("%v", err) } if inferredType.TypeID() != entry.expectedOutputType { t.Fatalf("Incorrect type inferred: %s", inferredType.TypeID()) } + assert.NoError(t, entry.validate(inferredType)) }) } diff --git a/internal/infer/oneof_expression.go b/internal/infer/oneof_expression.go new file mode 100644 index 00000000..83f8ed59 --- /dev/null +++ b/internal/infer/oneof_expression.go @@ -0,0 +1,41 @@ +package infer + +import ( + "fmt" + "go.flow.arcalot.io/pluginsdk/schema" +) + +// OneOfExpression stores the discriminator, and a key-value pair of all possible oneof values. +// The keys are the value for the discriminator, and the values are the YAML inputs, which can be +// inferred within the infer class. +type OneOfExpression struct { + Discriminator string + Options map[string]any + NodePath string +} + +func (o *OneOfExpression) String() string { + return fmt.Sprintf("{OneOf Expression; Discriminator: %s; Options: %v}", o.Discriminator, o.Options) +} + +// Type returns the OneOf type. Calculates the types of all possible oneof options for this. +func (o *OneOfExpression) Type( + internalDataModel schema.Scope, + functions map[string]schema.Function, + workflowContext map[string][]byte, +) (schema.Type, error) { + schemas := map[string]schema.Object{} + // Gets the type for all options. + for optionID, data := range o.Options { + inferredType, err := Type(data, internalDataModel, functions, workflowContext) + if err != nil { + return nil, err + } + inferredObjectType, isObject := inferredType.(schema.Object) + if !isObject { + return nil, fmt.Errorf("type of OneOf option is not an object; got %T", inferredType) + } + schemas[optionID] = inferredObjectType + } + return schema.NewOneOfStringSchema[any](schemas, o.Discriminator, false), nil +} diff --git a/internal/yaml/parser.go b/internal/yaml/parser.go index 9984b525..a7878e68 100644 --- a/internal/yaml/parser.go +++ b/internal/yaml/parser.go @@ -48,7 +48,8 @@ type Node interface { // Contents returns the contents as further Node items. For maps, this will contain exactly two nodes, while // for sequences this will contain as many nodes as there are items. For strings, this will contain no items. Contents() []Node - // MapKey selects a specific map key. If the node is not a map, this function panics. + // MapKey selects a specific map key. Returns the node and a bool that represents whether the key was present. + // If the node is not a map, this function panics. MapKey(key string) (Node, bool) // MapKeys lists all keys of a map. If the node is not a map, this function panics. 
MapKeys() []string diff --git a/workflow/any.go b/workflow/any.go index 50eaab3e..742df229 100644 --- a/workflow/any.go +++ b/workflow/any.go @@ -2,6 +2,7 @@ package workflow import ( "fmt" + "go.flow.arcalot.io/engine/internal/infer" "reflect" "go.flow.arcalot.io/expressions" @@ -43,7 +44,8 @@ func (a *anySchemaWithExpressions) Serialize(data any) (any, error) { } func (a *anySchemaWithExpressions) checkAndConvert(data any) (any, error) { - if _, ok := data.(expressions.Expression); ok { + switch data.(type) { + case expressions.Expression, infer.OneOfExpression, *infer.OneOfExpression: return data, nil } t := reflect.ValueOf(data) diff --git a/workflow/executor.go b/workflow/executor.go index ce446ef8..bf22e745 100644 --- a/workflow/executor.go +++ b/workflow/executor.go @@ -6,6 +6,7 @@ import ( "fmt" "go.flow.arcalot.io/engine/internal/util" "reflect" + "strconv" "strings" "go.arcalot.io/dgraph" @@ -196,7 +197,7 @@ func (e *executor) Prepare(workflow *Workflow, workflowContext map[string][]byte if err != nil { return nil, fmt.Errorf("failed to add workflow output node %s to DAG (%w)", outputID, err) } - if err := e.prepareDependencies(workflowContext, outputData, outputNode, internalDataModel, dag); err != nil { + if err := e.prepareDependencies(workflowContext, outputData, outputNode, []string{}, internalDataModel, dag); err != nil { return nil, fmt.Errorf("failed to build dependency tree for output (%w)", err) } } @@ -435,7 +436,7 @@ func (e *executor) connectStepDependencies( if data != nil { stageData[inputField] = data } - if err := e.prepareDependencies(workflowContext, data, currentStageNode, internalDataModel, dag); err != nil { + if err := e.prepareDependencies(workflowContext, data, currentStageNode, []string{}, internalDataModel, dag); err != nil { return fmt.Errorf("failed to build dependency tree for '%s' (%w)", currentStageNode.ID(), err) } } @@ -570,10 +571,12 @@ func (e *executor) preValidateCompatibility(rootSchema schema.Scope, inputField func (e *executor) createTypeStructure(rootSchema schema.Scope, inputField any, workflowContext map[string][]byte) (any, error) { // Expression, so the exact value may not be known yet. So just get the type from it. - if expr, ok := inputField.(expressions.Expression); ok { - // Is expression, so evaluate it. - e.logger.Debugf("Evaluating expression %s...", expr.String()) - return expr.Type(rootSchema, e.callableFunctionSchemas, workflowContext) + switch inputField := inputField.(type) { + case expressions.Expression: + e.logger.Debugf("Evaluating expression %s...", inputField.String()) + return inputField.Type(rootSchema, e.callableFunctionSchemas, workflowContext) + case *infer.OneOfExpression: + return inputField.Type(rootSchema, e.callableFunctionSchemas, workflowContext) } v := reflect.ValueOf(inputField) @@ -867,10 +870,11 @@ func (e *executor) loadSchema(stepKind step.Provider, stepID string, stepDataMap return runnableStep, nil } -func (e *executor) prepareDependencies( //nolint:gocognit,gocyclo +func (e *executor) prepareDependencies( workflowContext map[string][]byte, stepData any, currentNode dgraph.Node[*DAGItem], + pathInCurrentNode []string, outputSchema *schema.ScopeSchema, dag dgraph.DirectedGraph[*DAGItem], ) error { @@ -910,60 +914,9 @@ func (e *executor) prepareDependencies( //nolint:gocognit,gocyclo case reflect.Struct: switch s := stepData.(type) { case expressions.Expression: - // Evaluate the dependencies of the expression on the main data structure. 
- pathUnpackRequirements := expressions.UnpackRequirements{ - ExcludeDataRootPaths: false, - ExcludeFunctionRootPaths: true, // We don't need to setup DAG connections for them. - StopAtTerminals: true, // We do not need the extra info. We just need the connection. - IncludeKeys: false, - } - dependencies, err := s.Dependencies(outputSchema, e.callableFunctionSchemas, workflowContext, pathUnpackRequirements) - if err != nil { - return fmt.Errorf( - "failed to evaluate dependencies of the expression %s (%w)", - s.String(), - err, - ) - } - for _, dependency := range dependencies { - dependencyKind := dependency[1] - switch dependencyKind { - case WorkflowInputKey: - inputNode, err := dag.GetNodeByID(WorkflowInputKey) - if err != nil { - return fmt.Errorf("failed to find input node (%w)", err) - } - if err := inputNode.Connect(currentNode.ID()); err != nil { - decodedErr := &dgraph.ErrConnectionAlreadyExists{} - if !errors.As(err, &decodedErr) { - return fmt.Errorf("failed to connect input to %s (%w)", currentNode.ID(), err) - } - } - case WorkflowStepsKey: - var prevNodeID string - switch dependencyNodes := len(dependency); { - case dependencyNodes == 4: // Example: $.steps.example.outputs - prevNodeID = dependency[1:4].String() - case dependencyNodes >= 5: // Example: $.steps.example.outputs.success (or longer) - prevNodeID = dependency[1:5].String() - default: - return fmt.Errorf("invalid dependency %s", dependency.String()) - } - prevNode, err := dag.GetNodeByID(prevNodeID) - if err != nil { - return fmt.Errorf("failed to find depending node %s (%w)", prevNodeID, err) - } - if err := prevNode.Connect(currentNode.ID()); err != nil { - decodedErr := &dgraph.ErrConnectionAlreadyExists{} - if !errors.As(err, &decodedErr) { - return fmt.Errorf("failed to connect DAG node (%w)", err) - } - } - default: - return fmt.Errorf("bug: invalid dependency kind: %s", dependencyKind) - } - } - return nil + return e.prepareExprDependencies(s, workflowContext, currentNode, outputSchema, dag) + case *infer.OneOfExpression: + return e.prepareOneOfExprDependencies(s, workflowContext, currentNode, pathInCurrentNode, outputSchema, dag) default: return &ErrInvalidWorkflow{fmt.Errorf("unsupported struct/pointer type in workflow input: %T", stepData)} } @@ -971,7 +924,7 @@ func (e *executor) prepareDependencies( //nolint:gocognit,gocyclo v := reflect.ValueOf(stepData) for i := 0; i < v.Len(); i++ { value := v.Index(i).Interface() - if err := e.prepareDependencies(workflowContext, value, currentNode, outputSchema, dag); err != nil { + if err := e.prepareDependencies(workflowContext, value, currentNode, append(pathInCurrentNode, strconv.Itoa(i)), outputSchema, dag); err != nil { return wrapDependencyError(currentNode.ID(), fmt.Sprintf("%d", i), err) } } @@ -981,7 +934,7 @@ func (e *executor) prepareDependencies( //nolint:gocognit,gocyclo for _, reflectedKey := range v.MapKeys() { key := reflectedKey.Interface() value := v.MapIndex(reflectedKey).Interface() - if err := e.prepareDependencies(workflowContext, value, currentNode, outputSchema, dag); err != nil { + if err := e.prepareDependencies(workflowContext, value, currentNode, append(pathInCurrentNode, key.(string)), outputSchema, dag); err != nil { return wrapDependencyError(currentNode.ID(), fmt.Sprintf("%v", key), err) } } @@ -991,6 +944,117 @@ func (e *executor) prepareDependencies( //nolint:gocognit,gocyclo } } +func (e *executor) prepareExprDependencies( + expr expressions.Expression, workflowContext map[string][]byte, + currentNode dgraph.Node[*DAGItem], + 
outputSchema *schema.ScopeSchema, + dag dgraph.DirectedGraph[*DAGItem], +) error { + // Evaluate the dependencies of the expression on the main data structure. + pathUnpackRequirements := expressions.UnpackRequirements{ + ExcludeDataRootPaths: false, + ExcludeFunctionRootPaths: true, // We don't need to setup DAG connections for them. + StopAtTerminals: true, // We do not need the extra info. We just need the connection. + IncludeKeys: false, + } + dependencies, err := expr.Dependencies(outputSchema, e.callableFunctionSchemas, workflowContext, pathUnpackRequirements) + if err != nil { + return fmt.Errorf( + "failed to evaluate dependencies of the expression %s (%w)", + expr.String(), + err, + ) + } + for _, dependency := range dependencies { + dependencyKind := dependency[1] + switch dependencyKind { + case WorkflowInputKey: + inputNode, err := dag.GetNodeByID(WorkflowInputKey) + if err != nil { + return fmt.Errorf("failed to find input node (%w)", err) + } + if err := inputNode.Connect(currentNode.ID()); err != nil { + decodedErr := &dgraph.ErrConnectionAlreadyExists{} + if !errors.As(err, &decodedErr) { + return fmt.Errorf("failed to connect input to %s (%w)", currentNode.ID(), err) + } + } + case WorkflowStepsKey: + var prevNodeID string + switch dependencyNodes := len(dependency); { + case dependencyNodes == 4: // Example: $.steps.example.outputs + prevNodeID = dependency[1:4].String() + case dependencyNodes >= 5: // Example: $.steps.example.outputs.success (or longer) + prevNodeID = dependency[1:5].String() + default: + return fmt.Errorf("invalid dependency %s", dependency.String()) + } + prevNode, err := dag.GetNodeByID(prevNodeID) + if err != nil { + return fmt.Errorf("failed to find depending node %s (%w)", prevNodeID, err) + } + if err := currentNode.ConnectDependency(prevNode.ID(), dgraph.AndDependency); err != nil { + decodedErr := &dgraph.ErrConnectionAlreadyExists{} + if !errors.As(err, &decodedErr) { + return fmt.Errorf("failed to connect DAG node (%w)", err) + } + } + default: + return fmt.Errorf("bug: invalid dependency kind: %s", dependencyKind) + } + } + return nil +} + +func (e *executor) prepareOneOfExprDependencies( + expr *infer.OneOfExpression, + workflowContext map[string][]byte, + currentNode dgraph.Node[*DAGItem], + pathInCurrentNode []string, + outputSchema *schema.ScopeSchema, + dag dgraph.DirectedGraph[*DAGItem], +) error { + // Evaluate dependencies of all options for the oneof, then + // create OR dependencies for all of them. + // DAG nodes will need to be created for each option. + if len(expr.Options) == 0 { + return fmt.Errorf("oneof %s has no options", expr.String()) + } + // In case there are multiple OneOfs, each oneof needs its own node. + orNodeType := &DAGItem{ + Kind: DagItemKindOrGroup, + } + oneofDagNode, err := dag.AddNode( + currentNode.ID()+"."+strings.Join(pathInCurrentNode, "."), orNodeType) + if err != nil { + return err + } + err = currentNode.ConnectDependency(oneofDagNode.ID(), dgraph.AndDependency) + if err != nil { + return err + } + // Mark the node ID on the OneOfExpression. This mutates the expression, so make sure + // this is not operating on a copy of the schema for the data to be retained. 
+ expr.NodePath = oneofDagNode.ID() + for optionID, optionData := range expr.Options { + optionDagNode, err := dag.AddNode( + oneofDagNode.ID()+"."+optionID, orNodeType) + if err != nil { + return err + } + err = oneofDagNode.ConnectDependency(optionDagNode.ID(), dgraph.OrDependency) + if err != nil { + return err + } + err = e.prepareDependencies( + workflowContext, optionData, optionDagNode, []string{}, outputSchema, dag) + if err != nil { + return err + } + } + return nil +} + // DependencyError describes an error while preparing dependencies. type DependencyError struct { ID string `json:"id"` diff --git a/workflow/model.go b/workflow/model.go index d5e3adc4..c09c500b 100644 --- a/workflow/model.go +++ b/workflow/model.go @@ -170,6 +170,9 @@ const ( DAGItemKindStepStageOutput DAGItemKind = "stepStageOutput" // DAGItemKindOutput indicates a DAG node for the workflow output. DAGItemKindOutput DAGItemKind = "output" + // DagItemKindOrGroup indicates a DAG node used to complete a part of + // an input or output that needs dependencies grouped, typically for OR dependencies. + DagItemKindOrGroup DAGItemKind = "orGroup" ) // DAGItem is the internal structure of the DAG. diff --git a/workflow/workflow.go b/workflow/workflow.go index 9a0a0a33..06a4a798 100644 --- a/workflow/workflow.go +++ b/workflow/workflow.go @@ -4,10 +4,12 @@ package workflow import ( "context" "fmt" + "go.flow.arcalot.io/engine/internal/infer" "go.flow.arcalot.io/engine/internal/tablefmt" "go.flow.arcalot.io/engine/internal/tableprinter" "io" "reflect" + "strings" "sync" "time" @@ -130,7 +132,7 @@ func (e *executableWorkflow) Execute(ctx context.Context, serializedInput any) ( var stageHandler step.StageChangeHandler = &stageChangeHandler{ onStageChange: func( - step step.RunningStep, + _ step.RunningStep, previousStage *string, previousStageOutputID *string, previousStageOutput *any, @@ -146,7 +148,7 @@ func (e *executableWorkflow) Execute(ctx context.Context, serializedInput any) ( l.onStageComplete(stepID, previousStage, previousStageOutputID, previousStageOutput, wg) }, onStepComplete: func( - step step.RunningStep, + _ step.RunningStep, previousStage string, previousStageOutputID *string, previousStageOutput *any, @@ -159,7 +161,7 @@ func (e *executableWorkflow) Execute(ctx context.Context, serializedInput any) ( } l.onStageComplete(stepID, &previousStage, previousStageOutputID, previousStageOutput, wg) }, - onStepStageFailure: func(step step.RunningStep, stage string, wg *sync.WaitGroup, err error) { + onStepStageFailure: func(_ step.RunningStep, stage string, _ *sync.WaitGroup, err error) { if err == nil { e.logger.Debugf("Step %q stage %q declared that it will not produce an output", stepID, stage) } else { @@ -518,8 +520,17 @@ func (l *loopState) notifySteps() { //nolint:gocognit // The data structure that the particular node requires. One or more fields. May or may not contain expressions. inputData := nodeItem.Data if inputData == nil { - // No input data is needed. This is often the case for input nodes. - continue + switch nodeItem.Kind { + case DagItemKindOrGroup: + if err := node.ResolveNode(dgraph.Resolved); err != nil { + panic(fmt.Errorf("error occurred while resolving workflow OR group node (%s)", err.Error())) + } + l.notifySteps() // Needs to be called after resolving a node. + continue + default: + // No input data is needed. This is often the case for input nodes. + continue + } } // Resolve any expressions in the input data. 
@@ -591,8 +602,11 @@ func (l *loopState) notifySteps() { //nolint:gocognit } if err := node.ResolveNode(dgraph.Resolved); err != nil { - l.logger.Errorf("BUG: Error occurred while removing workflow output node (%w)", err) + l.logger.Errorf("BUG: Error occurred while resolving workflow output node (%s)", err.Error()) } + default: + panic(fmt.Errorf("unhandled case for type %s", nodeItem.Kind)) + } } } @@ -668,10 +682,52 @@ func (l *loopState) checkForDeadlocks(retries int, wg *sync.WaitGroup) { // resolveExpressions takes an inputData value potentially containing expressions and a dataModel containing data // for expressions and resolves the expressions contained in inputData using reflection. func (l *loopState) resolveExpressions(inputData any, dataModel any) (any, error) { - if expr, ok := inputData.(expressions.Expression); ok { + switch expr := inputData.(type) { + case expressions.Expression: l.logger.Debugf("Evaluating expression %s...", expr.String()) return expr.Evaluate(dataModel, l.callableFunctions, l.workflowContext) + case *infer.OneOfExpression: + l.logger.Debugf("Evaluating oneof expression %s...", expr.String()) + + // Get the node the OneOf uses to check which Or dependency resolved first (the others will either not be + // in the resolved list, or they will be obviated) + oneOfNode, err := l.dag.GetNodeByID(expr.NodePath) + if err != nil { + return nil, fmt.Errorf("failed to get node to resolve oneof expression (%w)", err) + } + dependencies := oneOfNode.ResolvedDependencies() + firstResolvedDependency := "" + for dependency, dependencyType := range dependencies { + if dependencyType == dgraph.OrDependency { + firstResolvedDependency = dependency + break + } else if dependencyType == dgraph.ObviatedDependency { + l.logger.Infof("Multiple OR cases triggered; skipping %q", dependency) + } + } + if firstResolvedDependency == "" { + return nil, fmt.Errorf("could not find resolved dependency for oneof expression %q", expr.String()) + } + optionID := strings.Replace(firstResolvedDependency, expr.NodePath+".", "", 1) + optionExpr, found := expr.Options[optionID] + if !found { + return nil, fmt.Errorf("could not find oneof option %q for oneof %q", optionID, expr) + } + // Still pass the current node in due to the possibility of a foreach within a foreach. 
+ subTypeResolution, err := l.resolveExpressions(optionExpr, dataModel) + if err != nil { + return nil, err + } + // Validate that it returned a map type (this is required because oneof subtypes need to be objects) + subTypeObjectMap, ok := subTypeResolution.(map[any]any) + if !ok { + return nil, fmt.Errorf("sub-type for oneof is not an object; got %T", subTypeResolution) + } + // Now add the discriminator + subTypeObjectMap[expr.Discriminator] = optionID + return subTypeObjectMap, nil } + v := reflect.ValueOf(inputData) switch v.Kind() { case reflect.Slice: diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go index 8ba4053d..dcb9af98 100644 --- a/workflow/workflow_test.go +++ b/workflow/workflow_test.go @@ -482,13 +482,13 @@ input: root: RootObject objects: RootObject: - id: RootObject + id: RootObject properties: {} steps: second_wait: wait_for: !expr $.steps.first_wait.outputs.success kind: foreach - items: + items: - wait_time_ms: 10 workflow: subworkflow.yaml first_wait: @@ -608,7 +608,7 @@ input: root: RootObject objects: RootObject: - id: RootObject + id: RootObject properties: {} steps: pre_wait: @@ -621,7 +621,7 @@ steps: second_wait: wait_for: !expr $.steps.first_wait.starting.started kind: foreach - items: + items: - wait_time_ms: 2 workflow: subworkflow.yaml first_wait: @@ -778,7 +778,6 @@ steps: input: wait_time_ms: 0 wait_2: - plugin: src: "n/a" deployment_type: "builtin" @@ -995,8 +994,8 @@ outputSchema: success: schema: root: RootObjectOut - objects: - RootObjectOut: + objects: + RootObjectOut: id: RootObjectOut properties: {}` @@ -1270,6 +1269,743 @@ func TestInputDisabledStepWorkflow(t *testing.T) { } } +var gracefullyDisabledStepWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_enabled: + type: + type_id: bool +steps: + simple_wait: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_enabled +outputs: + both: + simple_wait_output: !oneof + discriminator: "result" + one_of: + success_wait_output: !expr $.steps.simple_wait.outputs.success + disabled_wait_output: !expr $.steps.simple_wait.disabled.output +` + +func TestGracefullyDisabledStepWorkflow(t *testing.T) { + // Run a workflow where both the disabled output and the success output + // result in a single valid workflow output. 
+ preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, gracefullyDisabledStepWorkflow), + ) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": true, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "both") + outputDataMap := outputData.(map[any]any) + assert.MapContainsKeyAny[any](t, "simple_wait_output", outputDataMap) + outputDataMap = outputDataMap["simple_wait_output"].(map[any]any) + assert.MapContainsKeyAny[any](t, "result", outputDataMap) + assert.Equals(t, outputDataMap["result"], "success_wait_output") + // Test step disabled case + outputID, outputData, err = preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": false, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "both") + outputDataMap = outputData.(map[any]any) + assert.MapContainsKeyAny[any](t, "simple_wait_output", outputDataMap) + outputDataMap = outputDataMap["simple_wait_output"].(map[any]any) + assert.MapContainsKeyAny[any](t, "result", outputDataMap) + assert.Equals(t, outputDataMap["result"], "disabled_wait_output") +} + +var oneofWithOneOptionWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_enabled: + type: + type_id: bool +steps: + simple_wait: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_enabled +outputs: + workflow-success: + test_object: !oneof + discriminator: "option" + one_of: + test: !expr $.steps.simple_wait.outputs.success +` + +func TestSingleOneofOptionWorkflow(t *testing.T) { + // Runs a workflow where the output has a oneof that has one option that + // depends on a step that can be disabled. This ensures that oneof works + // properly with the unresolvable detection. + preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, oneofWithOneOptionWorkflow), + ) + // The workflow should pass with it enabled + outputID, _, err := preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": true, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "workflow-success") + // The workflow should fail with it disabled because the output cannot be resolved. 
+ _, _, err = preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": false, + }) + assert.Error(t, err) + var typedError *workflow.ErrNoMorePossibleOutputs + if !errors.As(err, &typedError) { + t.Fatalf("incorrect error type returned: %T (%s)", err, err) + } +} + +var manyOneOfOptionsWf = ` +version: v0.2.0 +input: + root: RootObject + objects: + RootObject: + id: RootObject + properties: + step_to_run: + type: + type_id: string +steps: + step_a: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "a" + step_b: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_to_run == "b" + step_c: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "c" + step_d: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "d" + step_e: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "e" + step_f: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "f" +outputs: + success: + ran_step: !oneof + discriminator: "result" + one_of: + a: !expr $.steps.step_a.outputs.success + b: !expr $.steps.step_b.outputs.success + c: !expr $.steps.step_c.outputs.success + d: !expr $.steps.step_d.outputs.success + e: !expr $.steps.step_e.outputs.success + f: !expr $.steps.step_f.outputs.success +` + +func TestManyOneofOptionsWorkflow(t *testing.T) { + type TestCase struct { + input map[string]any + expectedOutput map[any]any + } + + cases := []TestCase{ + { + input: map[string]any{ + "step_to_run": "a", + }, + expectedOutput: map[any]any{ + "result": "a", + }, + }, + { + input: map[string]any{ + "step_to_run": "b", + }, + expectedOutput: map[any]any{ + "result": "b", + "message": "Plugin slept for 0 ms.", + }, + }, + { + input: map[string]any{ + "step_to_run": "c", + }, + expectedOutput: map[any]any{ + "result": "c", + }, + }, + { + input: map[string]any{ + "step_to_run": "d", + }, + expectedOutput: map[any]any{ + "result": "d", + }, + }, + { + input: map[string]any{ + "step_to_run": "e", + }, + expectedOutput: map[any]any{ + "result": "e", + }, + }, + { + input: map[string]any{ + "step_to_run": "f", + }, + expectedOutput: map[any]any{ + "result": "f", + }, + }, + } + + // Run a workflow where both the disabled output and the success output + // result in a single valid workflow output. 
+ preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, manyOneOfOptionsWf), + ) + + for _, testCase := range cases { + t.Logf("Testing with input %v", testCase.input) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), testCase.input) + assert.NoError(t, err) + assert.Equals(t, outputID, "success") + outputDataMap := outputData.(map[any]any) + assert.MapContainsKeyAny[any](t, "ran_step", outputDataMap) + outputDataMap = outputDataMap["ran_step"].(map[any]any) + assert.Equals(t, outputDataMap, testCase.expectedOutput) + } +} + +var nestedOneOfWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_1_enabled: + type: + type_id: bool + step_2_enabled: + type: + type_id: bool +steps: + wait_1: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_1_enabled + wait_2: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_2_enabled +outputs: + all: + simple_wait_output: !oneof + discriminator: "result" + one_of: + a: + simple: !expr $.steps.wait_2.outputs.success + sub_oneof: !oneof + discriminator: "sub-result" + one_of: + # Use the same IDs to test for conflicts. + a: !expr $.steps.wait_1.outputs.success + b: !expr $.steps.wait_1.disabled.output + b: !expr $.steps.wait_2.disabled.output +` + +func TestNestedOneOfWorkflow(t *testing.T) { + type TestCase struct { + input map[string]any + expectedOutput map[any]any + } + + cases := []TestCase{ + { + input: map[string]any{ + "step_1_enabled": true, + "step_2_enabled": true, + }, + expectedOutput: map[any]any{ + "result": "a", + "simple": map[any]any{ + "message": "Plugin slept for 0 ms.", + }, + "sub_oneof": map[any]any{ + "sub-result": "a", + "message": "Plugin slept for 0 ms.", + }, + }, + }, + { + input: map[string]any{ + "step_1_enabled": false, + "step_2_enabled": true, + }, + expectedOutput: map[any]any{ + "result": "a", + "simple": map[any]any{ + "message": "Plugin slept for 0 ms.", + }, + "sub_oneof": map[any]any{ + "sub-result": "b", + "message": "Step wait_1/wait disabled", + }, + }, + }, + { + input: map[string]any{ + "step_1_enabled": true, + "step_2_enabled": false, + }, + expectedOutput: map[any]any{ + "result": "b", + "message": "Step wait_2/wait disabled", + }, + }, + { + input: map[string]any{ + "step_1_enabled": false, + "step_2_enabled": false, + }, + expectedOutput: map[any]any{ + "result": "b", + "message": "Step wait_2/wait disabled", + }, + }, + } + + // Run a workflow where both the disabled output and the success output + // result in a single valid workflow output. 
+ preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, nestedOneOfWorkflow), + ) + + for _, testCase := range cases { + t.Logf("Testing with input %v", testCase.input) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), testCase.input) + assert.NoError(t, err) + assert.Equals(t, outputID, "all") + outputDataMap := outputData.(map[any]any) + assert.MapContainsKeyAny[any](t, "simple_wait_output", outputDataMap) + outputDataMap = outputDataMap["simple_wait_output"].(map[any]any) + assert.Equals(t, outputDataMap, testCase.expectedOutput) + } +} + +var oneofInListWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_1_enabled: + type: + type_id: bool + step_2_enabled: + type: + type_id: bool +steps: + wait_1: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_1_enabled + wait_2: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 1 + enabled: !expr $.input.step_2_enabled +outputs: + all: + simple_wait_output: + - list_item: !oneof + discriminator: "result" + one_of: + a: !expr $.steps.wait_1.outputs.success + b: !expr $.steps.wait_1.disabled.output + - list_item: !oneof + discriminator: "result" + one_of: + a: !expr $.steps.wait_2.outputs.success + b: !expr $.steps.wait_2.disabled.output +` + +func TestOneOfInListWorkflow(t *testing.T) { + type TestCase struct { + input map[string]any + expectedOutput []any + } + + cases := []TestCase{ + { + input: map[string]any{ + "step_1_enabled": true, + "step_2_enabled": true, + }, + expectedOutput: []any{ + map[any]any{ + "list_item": map[any]any{ + "result": "a", + "message": "Plugin slept for 0 ms.", + }, + }, + map[any]any{ + "list_item": map[any]any{ + "result": "a", + "message": "Plugin slept for 1 ms.", + }, + }, + }, + }, + { + input: map[string]any{ + "step_1_enabled": true, + "step_2_enabled": false, + }, + expectedOutput: []any{ + map[any]any{ + "list_item": map[any]any{ + "result": "a", + "message": "Plugin slept for 0 ms.", + }, + }, + map[any]any{ + "list_item": map[any]any{ + "result": "b", + "message": "Step wait_2/wait disabled", + }, + }, + }, + }, + { + input: map[string]any{ + "step_1_enabled": false, + "step_2_enabled": true, + }, + expectedOutput: []any{ + map[any]any{ + "list_item": map[any]any{ + "result": "b", + "message": "Step wait_1/wait disabled", + }, + }, + map[any]any{ + "list_item": map[any]any{ + "result": "a", + "message": "Plugin slept for 1 ms.", + }, + }, + }, + }, + { + input: map[string]any{ + "step_1_enabled": false, + "step_2_enabled": false, + }, + expectedOutput: []any{ + map[any]any{ + "list_item": map[any]any{ + "result": "b", + "message": "Step wait_1/wait disabled", + }, + }, + map[any]any{ + "list_item": map[any]any{ + "result": "b", + "message": "Step wait_2/wait disabled", + }, + }, + }, + }, + } + + // Run a workflow where both the disabled output and the success output + // result in a single valid workflow output. 
+ preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, oneofInListWorkflow), + ) + + for _, testCase := range cases { + t.Logf("Testing with input %v", testCase.input) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), testCase.input) + assert.NoError(t, err) + assert.Equals(t, outputID, "all") + outputDataMap := outputData.(map[any]any) + assert.MapContainsKeyAny[any](t, "simple_wait_output", outputDataMap) + outputDataList := outputDataMap["simple_wait_output"].([]any) + assert.Equals(t, outputDataList, testCase.expectedOutput) + } +} + +var forEachWithOneOfWf = ` +version: v0.2.0 +input: + root: RootObject + objects: + RootObject: + id: RootObject + properties: + step_to_run: + type: + type_id: string +steps: + step_a: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false + enabled: !expr $.input.step_to_run == "a" + step_b: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_to_run == "b" + subwf_step: + kind: foreach + items: + - input_1: !oneof + discriminator: "result" + one_of: + a: !expr $.steps.step_a.outputs.success + b: !expr $.steps.step_b.outputs.success + input_2: !oneof + discriminator: "result" + one_of: + a: !expr $.steps.step_a.disabled.output + b: !expr $.steps.step_b.disabled.output + workflow: subworkflow.yaml +outputs: + success: + subwf_result: !expr $.steps.subwf_step.outputs +` + +var forEachWithOneOfSubWf = ` +version: v0.2.0 +input: + root: RootObject + objects: + RootObject: + id: RootObject + properties: + input_1: + type: + type_id: one_of_string + discriminator_field_name: result + types: + a: + type_id: object + id: hello-output + properties: {} + b: + type_id: object + id: output + properties: + message: + type: + type_id: string + input_2: + type: + type_id: one_of_string + discriminator_field_name: result + types: + a: + type_id: object + id: DisabledMessageOutput + properties: + message: + type: + type_id: string + b: + type_id: object + id: DisabledMessageOutput + properties: + message: + type: + type_id: string +steps: + placeholder_step: + plugin: + src: "n/a" + deployment_type: "builtin" + step: hello + input: + fail: !expr false +outputs: + success: !expr $.input +` + +func TestForeachWithOneOf(t *testing.T) { + // This test tests the oneof tag `!oneof` being used to create an input to + // a subworkflow. + // It would be redundant to use oneof with a single output, so we run two steps, and only one + // succeeds at a time. + // Since the subworkflow schema must match the input, this also validates that + // the inferred schema is correct. 
+ logConfig := log.Config{ + Level: log.LevelDebug, + Destination: log.DestinationStdout, + } + logger := log.New( + logConfig, + ) + cfg := &config.Config{ + Log: logConfig, + } + factories := workflowFactory{ + config: cfg, + } + deployerRegistry := deployerregistry.New( + deployer.Any(testimpl.NewFactory()), + ) + + pluginProvider := assert.NoErrorR[step.Provider](t)( + plugin.New(logger, deployerRegistry, map[string]interface{}{ + "builtin": map[string]any{ + "deployer_name": "test-impl", + "deploy_time": "0", + }, + }), + ) + stepRegistry, err := stepregistry.New( + pluginProvider, + lang.Must2(foreach.New(logger, factories.createYAMLParser, factories.createWorkflow)), + ) + assert.NoError(t, err) + + factories.stepRegistry = stepRegistry + executor := lang.Must2(workflow.NewExecutor( + logger, + cfg, + stepRegistry, + builtinfunctions.GetFunctions(), + )) + wf := lang.Must2(workflow.NewYAMLConverter(stepRegistry).FromYAML([]byte(forEachWithOneOfWf))) + preparedWorkflow := lang.Must2(executor.Prepare(wf, map[string][]byte{ + "subworkflow.yaml": []byte(forEachWithOneOfSubWf), + })) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_to_run": "a", + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "success") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "subwf_result": map[string]any{ + "success": map[string]any{ + "data": []any{ + map[string]any{ + "input_1": map[string]any{ + "result": "a", + }, + "input_2": map[string]any{ + "result": "b", + "message": "Step step_b/wait disabled", + }, + }, + }, + }, + }, + }) + + outputID, outputData, err = preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_to_run": "b", + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "success") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "subwf_result": map[string]any{ + "success": map[string]any{ + "data": []any{ + map[string]any{ + "input_1": map[string]any{ + "result": "b", + "message": "Plugin slept for 0 ms.", + }, + "input_2": map[string]any{ + "result": "a", + "message": "Step step_a/hello disabled", + }, + }, + }, + }, + }, + }) +} + var dynamicDisabledStepWorkflow = ` version: v0.2.0 input: diff --git a/workflow/yaml.go b/workflow/yaml.go index a1d5b2d8..5ed051a2 100644 --- a/workflow/yaml.go +++ b/workflow/yaml.go @@ -2,6 +2,7 @@ package workflow import ( "fmt" + "go.flow.arcalot.io/engine/internal/infer" "strings" "go.flow.arcalot.io/engine/internal/step" @@ -49,16 +50,63 @@ func (y yamlConverter) FromYAML(data []byte) (*Workflow, error) { return workflow, nil } +// YamlOneOfKey is the key to specify the oneof options within a !oneof section. +const YamlOneOfKey = "one_of" + +// YamlDiscriminatorKey is the key to specify the discriminator inside a !oneof section. +const YamlDiscriminatorKey = "discriminator" + +// YamlOneOfTag is the yaml tag that allows the section to be interpreted as a OneOf. 
+const YamlOneOfTag = "!oneof"
+
+func buildOneOfExpressions(data yaml.Node, path []string) (any, error) {
+	if data.Type() != yaml.TypeIDMap {
+		return nil, fmt.Errorf("!oneof found on non-map node at %s; expected a map with a list of options and the discriminator", strings.Join(path, " -> "))
+	}
+	discriminatorNode, found := data.MapKey(YamlDiscriminatorKey)
+	if !found {
+		return nil, fmt.Errorf("key %q not present within !oneof at %q", YamlDiscriminatorKey, strings.Join(path, " -> "))
+	}
+	if discriminatorNode.Type() != yaml.TypeIDString {
+		return nil, fmt.Errorf("%q within !oneof should be a string; got %s", YamlDiscriminatorKey, discriminatorNode.Type())
+	}
+	oneOfOptionsNode, found := data.MapKey(YamlOneOfKey)
+	if !found {
+		return nil, fmt.Errorf("key %q not present within !oneof at %q", YamlOneOfKey, strings.Join(path, " -> "))
+	}
+	if oneOfOptionsNode.Type() != yaml.TypeIDMap {
+		return nil, fmt.Errorf("%q within !oneof should be a map; got %s", YamlOneOfKey, oneOfOptionsNode.Type())
+	}
+	options := map[string]any{}
+	for _, optionNodeKey := range oneOfOptionsNode.MapKeys() {
+		optionNode, _ := oneOfOptionsNode.MapKey(optionNodeKey)
+		var err error
+		options[optionNodeKey], err = yamlBuildExpressions(optionNode, append(path, optionNodeKey))
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	discriminator := discriminatorNode.Value()
+	return &infer.OneOfExpression{
+		Discriminator: discriminator,
+		Options:       options,
+	}, nil
+}
+
 func yamlBuildExpressions(data yaml.Node, path []string) (any, error) {
-	if data.Tag() == "!expr" {
+	switch data.Tag() {
+	case "!expr":
 		if data.Type() != yaml.TypeIDString {
-			return nil, fmt.Errorf("!!expr found on non-string node at %s", strings.Join(path, " -> "))
+			return nil, fmt.Errorf("!expr found on non-string node at %s", strings.Join(path, " -> "))
 		}
 		expr, err := expressions.New(data.Value())
 		if err != nil {
 			return nil, fmt.Errorf("failed to compile expression at %s (%w)", strings.Join(path, " -> "), err)
 		}
 		return expr, nil
+	case YamlOneOfTag:
+		return buildOneOfExpressions(data, path)
 	}
 	switch data.Type() {
 	case yaml.TypeIDString:
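
For reference, below is a minimal usage sketch (not part of this diff) of the new OneOfExpression type-inference API, mirroring the "oneof-expression" case in infer_test.go above. The package and function names here are illustrative only, and since internal/infer is an internal package this only compiles from inside the engine module; treat it as a sketch of the call shape under those assumptions, not a drop-in program.

package infer_example

import (
	"fmt"

	"go.arcalot.io/lang"
	"go.flow.arcalot.io/engine/internal/infer"
	"go.flow.arcalot.io/expressions"
	"go.flow.arcalot.io/pluginsdk/schema"
)

// Example builds a OneOfExpression with two options, one holding a literal
// value and one holding an expression that resolves against the scope, then
// infers its schema, as the tests in this diff do.
func Example() error {
	// Scope that "$.a" resolves against (same shape as the test data model).
	dataModel := schema.NewScopeSchema(
		schema.NewObjectSchema("root", map[string]*schema.PropertySchema{
			"a": schema.NewPropertySchema(
				schema.NewStringSchema(nil, nil, nil),
				nil, true, nil, nil, nil, nil, nil,
			),
		}),
	)
	oneOf := &infer.OneOfExpression{
		Discriminator: "option",
		Options: map[string]any{
			"a": map[string]any{"value-1": 1},
			"b": map[string]any{"value-2": lang.Must2(expressions.New("$.a"))},
		},
	}
	// Each option is inferred as an object schema; the combined result is a
	// one-of-string schema keyed on the "option" discriminator field.
	inferredType, err := oneOf.Type(dataModel, nil, nil)
	if err != nil {
		return err
	}
	fmt.Println(inferredType.TypeID() == schema.TypeIDOneOfString) // true
	return nil
}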