From 7878b8e9e0fd985a76a1fb5db29fff7419e598b2 Mon Sep 17 00:00:00 2001
From: Dan Rammer
Date: Thu, 30 Nov 2023 09:13:05 -0600
Subject: [PATCH] Implemented simple echo plugin for testing (#4489)

* implemented simple echo plugin for testing

Signed-off-by: Daniel Rammer

* using external timestamp tracking to work with ArrayNode

Signed-off-by: Daniel Rammer

* Implemented sleep configuration

Signed-off-by: Daniel Rammer

* fixed lint

Signed-off-by: Daniel Rammer

* moved echo plugin to testing package

Signed-off-by: Daniel Rammer

* updated config to set pflags

Signed-off-by: Daniel Rammer

---------

Signed-off-by: Daniel Rammer
Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
---
Review revision: echo.go now guards taskStartTimes with a sync.Mutex and the
timer goroutine no longer reads tCtx after Handle has returned; doc comments
were added to config.go and echo.go. Hunk sizes and the diffstat reflect these
changes; the blob ids on the "index" lines are carried over from the original
patch.

 .../go/tasks/plugins/testing/config.go        |  26 +++
 .../go/tasks/plugins/testing/config_flags.go  |  55 +++++
 .../plugins/testing/config_flags_test.go      | 116 +++++++++++
 flyteplugins/go/tasks/plugins/testing/echo.go | 197 ++++++++++++++++++
 flytepropeller/plugins/loader.go              |   1 +
 5 files changed, 395 insertions(+)
 create mode 100644 flyteplugins/go/tasks/plugins/testing/config.go
 create mode 100755 flyteplugins/go/tasks/plugins/testing/config_flags.go
 create mode 100755 flyteplugins/go/tasks/plugins/testing/config_flags_test.go
 create mode 100644 flyteplugins/go/tasks/plugins/testing/echo.go

diff --git a/flyteplugins/go/tasks/plugins/testing/config.go b/flyteplugins/go/tasks/plugins/testing/config.go
new file mode 100644
index 0000000000..9508bcafa7
--- /dev/null
+++ b/flyteplugins/go/tasks/plugins/testing/config.go
@@ -0,0 +1,26 @@
+package testing
+
+import (
+	"time"
+
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/config"
+	flytestdconfig "github.com/flyteorg/flyte/flytestdlib/config"
+)
+
+//go:generate pflags Config --default-var=defaultConfig
+
+var (
+	// defaultConfig is the configuration used when no overrides are provided.
+	defaultConfig = Config{
+		SleepDuration: flytestdconfig.Duration{Duration: 0 * time.Second},
+	}
+
+	// ConfigSection is the echo plugin's configuration sub-section, registered under the echo task type.
+	ConfigSection = config.MustRegisterSubSection(echoTaskType, &defaultConfig)
+)
+
+// Config holds the settings of the echo plugin.
+type Config struct {
+	// SleepDuration indicates the amount of time before transitioning to success
+	SleepDuration flytestdconfig.Duration `json:"sleep-duration" pflag:",Indicates the amount of time before transitioning to success"`
+}
diff --git a/flyteplugins/go/tasks/plugins/testing/config_flags.go b/flyteplugins/go/tasks/plugins/testing/config_flags.go
new file mode 100755
index 0000000000..f4b2e60c7a
--- /dev/null
+++ b/flyteplugins/go/tasks/plugins/testing/config_flags.go
@@ -0,0 +1,55 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
+
+package testing
+
+import (
+	"encoding/json"
+	"reflect"
+
+	"fmt"
+
+	"github.com/spf13/pflag"
+)
+
+// If v is a pointer, it will get its element value or the zero value of the element type.
+// If v is not a pointer, it will return it as is.
+func (Config) elemValueOrNil(v interface{}) interface{} {
+	if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr {
+		if reflect.ValueOf(v).IsNil() {
+			return reflect.Zero(t.Elem()).Interface()
+		} else {
+			return reflect.ValueOf(v).Interface()
+		}
+	} else if v == nil {
+		return reflect.Zero(t).Interface()
+	}
+
+	return v
+}
+
+func (Config) mustJsonMarshal(v interface{}) string {
+	raw, err := json.Marshal(v)
+	if err != nil {
+		panic(err)
+	}
+
+	return string(raw)
+}
+
+func (Config) mustMarshalJSON(v json.Marshaler) string {
+	raw, err := v.MarshalJSON()
+	if err != nil {
+		panic(err)
+	}
+
+	return string(raw)
+}
+
+// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the
+// flags is json-name.json-sub-name... etc.
+func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
+	cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError)
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "sleep-duration"), defaultConfig.SleepDuration.String(), "Indicates the amount of time before transitioning to success")
+	return cmdFlags
+}
diff --git a/flyteplugins/go/tasks/plugins/testing/config_flags_test.go b/flyteplugins/go/tasks/plugins/testing/config_flags_test.go
new file mode 100755
index 0000000000..023e8986e0
--- /dev/null
+++ b/flyteplugins/go/tasks/plugins/testing/config_flags_test.go
@@ -0,0 +1,116 @@
+// Code generated by go generate; DO NOT EDIT.
+// This file was generated by robots.
+
+package testing
+
+import (
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/mitchellh/mapstructure"
+	"github.com/stretchr/testify/assert"
+)
+
+var dereferencableKindsConfig = map[reflect.Kind]struct{}{
+	reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {},
+}
+
+// Checks if t is a kind that can be dereferenced to get its underlying type.
+func canGetElementConfig(t reflect.Kind) bool {
+	_, exists := dereferencableKindsConfig[t]
+	return exists
+}
+
+// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the
+// object. Otherwise, it'll just pass on the original data.
+func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) {
+	unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
+	if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) ||
+		(canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) {
+
+		raw, err := json.Marshal(data)
+		if err != nil {
+			fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err)
+			return data, nil
+		}
+
+		res := reflect.New(to).Interface()
+		err = json.Unmarshal(raw, &res)
+		if err != nil {
+			fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err)
+			return data, nil
+		}
+
+		return res, nil
+	}
+
+	return data, nil
+}
+
+func decode_Config(input, result interface{}) error {
+	config := &mapstructure.DecoderConfig{
+		TagName:          "json",
+		WeaklyTypedInput: true,
+		Result:           result,
+		DecodeHook: mapstructure.ComposeDecodeHookFunc(
+			mapstructure.StringToTimeDurationHookFunc(),
+			mapstructure.StringToSliceHookFunc(","),
+			jsonUnmarshalerHookConfig,
+		),
+	}
+
+	decoder, err := mapstructure.NewDecoder(config)
+	if err != nil {
+		return err
+	}
+
+	return decoder.Decode(input)
+}
+
+func join_Config(arr interface{}, sep string) string {
+	listValue := reflect.ValueOf(arr)
+	strs := make([]string, 0, listValue.Len())
+	for i := 0; i < listValue.Len(); i++ {
+		strs = append(strs, fmt.Sprintf("%v", listValue.Index(i)))
+	}
+
+	return strings.Join(strs, sep)
+}
+
+func testDecodeJson_Config(t *testing.T, val, result interface{}) {
+	assert.NoError(t, decode_Config(val, result))
+}
+
+func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) {
+	assert.NoError(t, decode_Config(vStringSlice, result))
+}
+
+func TestConfig_GetPFlagSet(t *testing.T) {
+	val := Config{}
+	cmdFlags := val.GetPFlagSet("")
+	assert.True(t, cmdFlags.HasFlags())
+}
+
+func TestConfig_SetFlags(t *testing.T) {
+	actual := Config{}
+	cmdFlags := actual.GetPFlagSet("")
+	assert.True(t, cmdFlags.HasFlags())
+
+	t.Run("Test_sleep-duration", func(t *testing.T) {
+
+		t.Run("Override", func(t *testing.T) {
+			testValue := defaultConfig.SleepDuration.String()
+
+			cmdFlags.Set("sleep-duration", testValue)
+			if vString, err := cmdFlags.GetString("sleep-duration"); err == nil {
+				testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.SleepDuration)
+
+			} else {
+				assert.FailNow(t, err.Error())
+			}
+		})
+	})
+}
diff --git a/flyteplugins/go/tasks/plugins/testing/echo.go b/flyteplugins/go/tasks/plugins/testing/echo.go
new file mode 100644
index 0000000000..7c5587dd71
--- /dev/null
+++ b/flyteplugins/go/tasks/plugins/testing/echo.go
@@ -0,0 +1,197 @@
+package testing
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils"
+	"github.com/flyteorg/flyte/flytestdlib/logger"
+	"github.com/flyteorg/flyte/flytestdlib/storage"
+)
+
+const (
+	echoTaskType = "echo"
+)
+
+// EchoPlugin is a core plugin intended for testing. Once the configured sleep duration has elapsed
+// since a task was first handled, it copies the task's input to its output and reports success.
+type EchoPlugin struct {
+	enqueueOwner core.EnqueueOwner
+
+	// mu guards taskStartTimes: Handle and Finalize may be invoked from different goroutines and
+	// unsynchronized concurrent map access is a data race in Go.
+	mu sync.Mutex
+	// taskStartTimes records when each task was first handled, keyed by generated task execution name.
+	taskStartTimes map[string]time.Time
+}
+
+// GetID returns the identifier of this plugin.
+func (e *EchoPlugin) GetID() string {
+	return echoTaskType
+}
+
+// GetProperties returns the zero-valued plugin properties.
+func (e *EchoPlugin) GetProperties() core.PluginProperties {
+	return core.PluginProperties{}
+}
+
+// startTimeFor returns the time at which the task was first handled, recording the current time if
+// the task has not been seen before. The boolean result is true when this call recorded the time.
+func (e *EchoPlugin) startTimeFor(taskExecutionID string) (time.Time, bool) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	if startTime, exists := e.taskStartTimes[taskExecutionID]; exists {
+		return startTime, false
+	}
+
+	startTime := time.Now()
+	e.taskStartTimes[taskExecutionID] = startTime
+	return startTime, true
+}
+
+// Handle reports the task as running until the configured sleep duration has elapsed since the task
+// was first handled, then copies the input to the output and transitions the task to success.
+func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
+	echoConfig := ConfigSection.GetConfig().(*Config)
+	sleepDuration := echoConfig.SleepDuration.Duration
+
+	taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
+	startTime, firstSeen := e.startTimeFor(taskExecutionID)
+	if firstSeen {
+		// resolve the owner up front so the timer goroutine does not touch tCtx after Handle returns
+		ownerID := tCtx.TaskExecutionMetadata().GetOwnerID()
+
+		// start timer to enqueue owner once task sleep duration has elapsed
+		go func() {
+			time.Sleep(sleepDuration)
+			if err := e.enqueueOwner(ownerID); err != nil {
+				logger.Warnf(ctx, "failed to enqueue owner [%s]: %v", ownerID, err)
+			}
+		}()
+	}
+
+	if time.Since(startTime) >= sleepDuration {
+		// copy inputs to outputs
+		inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx)
+		if err != nil {
+			return core.UnknownTransition, err
+		}
+
+		if len(inputToOutputVariableMappings) > 0 {
+			inputLiterals, err := tCtx.InputReader().Get(ctx)
+			if err != nil {
+				return core.UnknownTransition, err
+			}
+
+			outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings))
+			for inputVariableName, outputVariableName := range inputToOutputVariableMappings {
+				outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName]
+			}
+
+			outputLiteralMap := &idlcore.LiteralMap{
+				Literals: outputLiterals,
+			}
+
+			outputFile := tCtx.OutputWriter().GetOutputPath()
+			if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
+				return core.UnknownTransition, err
+			}
+
+			or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), tCtx.MaxDatasetSizeBytes())
+			if err = tCtx.OutputWriter().Put(ctx, or); err != nil {
+				return core.UnknownTransition, err
+			}
+		}
+
+		return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
+	}
+
+	return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
+}
+
+// Abort is a no-op for this plugin.
+func (e *EchoPlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
+	return nil
+}
+
+// Finalize forgets the start time recorded for the task so the tracking map does not keep growing.
+func (e *EchoPlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
+	taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	delete(e.taskStartTimes, taskExecutionID)
+	return nil
+}
+
+// compileInputToOutputVariableMappings maps each input variable name of the task interface to an
+// output variable name. It fails when the task template cannot be read, when the numbers of input
+// and output variables differ, or when there is more than one input variable.
+func compileInputToOutputVariableMappings(ctx context.Context, tCtx core.TaskExecutionContext) (map[string]string, error) {
+	// validate outputs are castable from inputs otherwise error as this plugin is not applicable
+	taskTemplate, err := tCtx.TaskReader().Read(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read TaskTemplate: [%w]", err)
+	}
+
+	var inputs, outputs map[string]*idlcore.Variable
+	if taskTemplate.Interface != nil {
+		if taskTemplate.Interface.Inputs != nil {
+			inputs = taskTemplate.Interface.Inputs.Variables
+		}
+		if taskTemplate.Interface.Outputs != nil {
+			outputs = taskTemplate.Interface.Outputs.Variables
+		}
+	}
+
+	if len(inputs) != len(outputs) {
+		return nil, fmt.Errorf("the number of input [%d] and output [%d] variables does not match", len(inputs), len(outputs))
+	} else if len(inputs) > 1 {
+		return nil, fmt.Errorf("this plugin does not currently support more than one input variable")
+	}
+
+	inputToOutputVariableMappings := make(map[string]string)
+	outputVariableNameUsed := make(map[string]struct{})
+	for inputVariableName := range inputs {
+		firstCastableOutputName := ""
+		for outputVariableName := range outputs {
+			// TODO - need to check if types are castable to support multiple values
+			if _, ok := outputVariableNameUsed[outputVariableName]; !ok {
+				firstCastableOutputName = outputVariableName
+				break
+			}
+		}
+
+		if len(firstCastableOutputName) == 0 {
+			return nil, fmt.Errorf("no castable output variable found for input variable [%s]", inputVariableName)
+		}
+
+		outputVariableNameUsed[firstCastableOutputName] = struct{}{}
+		inputToOutputVariableMappings[inputVariableName] = firstCastableOutputName
+	}
+
+	return inputToOutputVariableMappings, nil
+}
+
+// init registers the echo plugin as a default core plugin for the echo task type.
+func init() {
+	pluginmachinery.PluginRegistry().RegisterCorePlugin(
+		core.PluginEntry{
+			ID:                  echoTaskType,
+			RegisteredTaskTypes: []core.TaskType{echoTaskType},
+			LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
+				return &EchoPlugin{
+					enqueueOwner:   iCtx.EnqueueOwner(),
+					taskStartTimes: make(map[string]time.Time),
+				}, nil
+			},
+			IsDefault: true,
+		},
+	)
+}
diff --git a/flytepropeller/plugins/loader.go b/flytepropeller/plugins/loader.go
index 8e1deffa67..a2983abb15 100644
--- a/flytepropeller/plugins/loader.go
+++ b/flytepropeller/plugins/loader.go
@@ -13,6 +13,7 @@ import (
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/pod"
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray"
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/spark"
+	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/testing"
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/athena"
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/bigquery"
 	_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/databricks"