Skip to content

Commit

Permalink
Implemented simple echo plugin for testing (#4489)
Browse files Browse the repository at this point in the history
* implemented simple echo plugin for testing

Signed-off-by: Daniel Rammer <[email protected]>

* using external timestamp tracking to work with ArrayNode

Signed-off-by: Daniel Rammer <[email protected]>

* Implemented sleep configuration

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint

Signed-off-by: Daniel Rammer <[email protected]>

* moved echo plugin to testing package

Signed-off-by: Daniel Rammer <[email protected]>

* updated config to set pflags

Signed-off-by: Daniel Rammer <[email protected]>

---------

Signed-off-by: Daniel Rammer <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
hamersaw and eapolinario authored Nov 30, 2023
1 parent 7d712de commit 7878b8e
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 0 deletions.
23 changes: 23 additions & 0 deletions flyteplugins/go/tasks/plugins/testing/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package testing

import (
"time"

"github.com/flyteorg/flyte/flyteplugins/go/tasks/config"
flytestdconfig "github.com/flyteorg/flyte/flytestdlib/config"
)

//go:generate pflags Config --default-var=defaultConfig

var (
defaultConfig = Config{
SleepDuration: flytestdconfig.Duration{Duration: 0 * time.Second},
}

ConfigSection = config.MustRegisterSubSection(echoTaskType, &defaultConfig)
)

type Config struct {
// SleepDuration indicates the amount of time before transitioning to success
SleepDuration flytestdconfig.Duration `json:"sleep-duration" pflag:",Indicates the amount of time before transitioning to success"`
}
55 changes: 55 additions & 0 deletions flyteplugins/go/tasks/plugins/testing/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

116 changes: 116 additions & 0 deletions flyteplugins/go/tasks/plugins/testing/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

161 changes: 161 additions & 0 deletions flyteplugins/go/tasks/plugins/testing/echo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package testing

import (
"context"
"fmt"
"time"

idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

const (
echoTaskType = "echo"
)

type EchoPlugin struct {
enqueueOwner core.EnqueueOwner
taskStartTimes map[string]time.Time
}

func (e *EchoPlugin) GetID() string {
return echoTaskType
}

func (e *EchoPlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
echoConfig := ConfigSection.GetConfig().(*Config)

var startTime time.Time
var exists bool
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
if startTime, exists = e.taskStartTimes[taskExecutionID]; !exists {
startTime = time.Now()
e.taskStartTimes[taskExecutionID] = startTime

// start timer to enqueue owner once task sleep duration has elapsed
go func() {
time.Sleep(echoConfig.SleepDuration.Duration)
if err := e.enqueueOwner(tCtx.TaskExecutionMetadata().GetOwnerID()); err != nil {
logger.Warnf(ctx, "failed to enqueue owner [%s]: %v", tCtx.TaskExecutionMetadata().GetOwnerID(), err)
}
}()
}

if time.Since(startTime) >= echoConfig.SleepDuration.Duration {
// copy inputs to outputs
inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}

if len(inputToOutputVariableMappings) > 0 {
inputLiterals, err := tCtx.InputReader().Get(ctx)
if err != nil {
return core.UnknownTransition, err
}

outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings))
for inputVariableName, outputVariableName := range inputToOutputVariableMappings {
outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName]
}

outputLiteralMap := &idlcore.LiteralMap{
Literals: outputLiterals,
}

outputFile := tCtx.OutputWriter().GetOutputPath()
if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
return core.UnknownTransition, err
}

or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), tCtx.MaxDatasetSizeBytes())
if err = tCtx.OutputWriter().Put(ctx, or); err != nil {
return core.UnknownTransition, err
}
}

return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
}

func (e *EchoPlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
return nil
}

func (e *EchoPlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
delete(e.taskStartTimes, taskExecutionID)
return nil
}

func compileInputToOutputVariableMappings(ctx context.Context, tCtx core.TaskExecutionContext) (map[string]string, error) {
// validate outputs are castable from inputs otherwise error as this plugin is not applicable
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return nil, fmt.Errorf("failed to read TaskTemplate: [%w]", err)
}

var inputs, outputs map[string]*idlcore.Variable
if taskTemplate.Interface != nil {
if taskTemplate.Interface.Inputs != nil {
inputs = taskTemplate.Interface.Inputs.Variables
}
if taskTemplate.Interface.Outputs != nil {
outputs = taskTemplate.Interface.Outputs.Variables
}
}

if len(inputs) != len(outputs) {
return nil, fmt.Errorf("the number of input [%d] and output [%d] variables does not match", len(inputs), len(outputs))
} else if len(inputs) > 1 {
return nil, fmt.Errorf("this plugin does not currently support more than one input variable")
}

inputToOutputVariableMappings := make(map[string]string)
outputVariableNameUsed := make(map[string]struct{})
for inputVariableName := range inputs {
firstCastableOutputName := ""
for outputVariableName := range outputs {
// TODO - need to check if types are castable to support multiple values
if _, ok := outputVariableNameUsed[outputVariableName]; !ok {
firstCastableOutputName = outputVariableName
break
}
}

if len(firstCastableOutputName) == 0 {
return nil, fmt.Errorf("no castable output variable found for input variable [%s]", inputVariableName)
}

outputVariableNameUsed[firstCastableOutputName] = struct{}{}
inputToOutputVariableMappings[inputVariableName] = firstCastableOutputName
}

return inputToOutputVariableMappings, nil
}

func init() {
pluginmachinery.PluginRegistry().RegisterCorePlugin(
core.PluginEntry{
ID: echoTaskType,
RegisteredTaskTypes: []core.TaskType{echoTaskType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return &EchoPlugin{
enqueueOwner: iCtx.EnqueueOwner(),
taskStartTimes: make(map[string]time.Time),
}, nil
},
IsDefault: true,
},
)
}
1 change: 1 addition & 0 deletions flytepropeller/plugins/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/pod"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/ray"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/spark"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/testing"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/athena"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/bigquery"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/databricks"
Expand Down

0 comments on commit 7878b8e

Please sign in to comment.