Skip to content

Commit

Permalink
Transfer commits
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Oct 3, 2023
2 parents d9586b0 + f6f3342 commit 09112b8
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 21 deletions.
6 changes: 6 additions & 0 deletions flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ require (
sigs.k8s.io/yaml v1.3.0 // indirect
)

<<<<<<< HEAD
replace (
github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d
github.com/flyteorg/flyte/datacatalog => ../datacatalog
Expand All @@ -144,3 +145,8 @@ replace (
github.com/flyteorg/flyte/flytestdlib => ../flytestdlib
github.com/flyteorg/flyteidl => ../flyteidl
)
=======
replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d

replace github.com/flyteorg/flyteidl => /mnt/c/code/dev/flyteidl
>>>>>>> flyteplugins/dev-sync-plugin
20 changes: 19 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
type CorePlugin struct {
id string
p webapi.AsyncPlugin
sp webapi.SyncPlugin
cache cache.AutoRefresh
tokenAllocator tokenAllocator
metrics Metrics
Expand Down Expand Up @@ -68,12 +69,28 @@ func (c CorePlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

// syncHandle
// TODO: ADD Sync Handle
func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader())
if err != nil {
return core.UnknownTransition, err
}

taskTemplate, err := tCtx.TaskReader().Read(ctx)

if taskTemplate.Type == "dispatcher" {
res, err := c.sp.Do(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}
logger.Infof(ctx, "@@@ SyncPlugin [%v] returned result: %v", c.GetID(), res)
// if err := tCtx.PluginStateWriter().Put(pluginStateVersion, nextState); err != nil {
// return core.UnknownTransition, err
// }
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

var nextState *State
var phaseInfo core.PhaseInfo
switch incomingState.Phase {
Expand Down Expand Up @@ -165,7 +182,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug
RegisteredTaskTypes: pluginEntry.SupportedTaskTypes,
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (
core.Plugin, error) {
p, err := pluginEntry.PluginLoader(ctx, iCtx)
p, sp, err := pluginEntry.PluginLoader(ctx, iCtx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -205,6 +222,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug
return CorePlugin{
id: pluginEntry.ID,
p: p,
sp: sp,
cache: resourceCache,
metrics: newMetrics(iCtx.MetricsScope()),
tokenAllocator: newTokenAllocator(c),
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/webapi/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

// A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed
// that the plugin loader will be called before any Handle/Abort/Finalize functions are invoked
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, error)
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, SyncPlugin, error)

// PluginEntry is a structure that is used to indicate to the system a K8s plugin
type PluginEntry struct {
Expand Down Expand Up @@ -150,5 +150,5 @@ type SyncPlugin interface {
GetConfig() PluginConfig

// Do performs the action associated with this plugin.
Do(ctx context.Context, tCtx TaskExecutionContext) (phase pluginsCore.PhaseInfo, err error)
Do(ctx context.Context, tCtx TaskExecutionContextReader) (latest Resource, err error)
}
85 changes: 79 additions & 6 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,73 @@ func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionCo
return "default", p.cfg.ResourceConstraints, nil
}

func (p Plugin) Do(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (latest webapi.Resource, err error) {
// write the resource here
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, err
}

inputs, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, err
}
taskTemplate.GetContainer().Args = modifiedArgs
}

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent agent with error: %v", err)
}

client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "DoTask", agent)

defer cancel()

// taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
// write it in agent?

logger.Infof(ctx, "@@@ inputs: [%v]", inputs)
logger.Infof(ctx, "@@@ taskTemplate: [%v]", taskTemplate)

res, err := client.DoTask(finalCtx, &admin.DoTaskRequest{Inputs: inputs, Template: taskTemplate})
if err != nil {
return nil, err
}

logger.Infof(ctx, "@@@ res.Resource.State: [%v]", res.Resource.State)
logger.Infof(ctx, "@@@ res.Resource.Outputs: [%v]", res.Resource.Outputs)

return &ResourceWrapper{
State: res.Resource.State,
Outputs: res.Resource.Outputs,
}, nil
}

// todo: write the output

// we can get the task type in core.go
/*
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
taskTemplate.type = spark, dispatcher ...
*/
func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
webapi.Resource, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -298,6 +365,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte
return context.WithTimeout(ctx, timeout)
}

// TODO: Add sync agent plugin
func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
if len(supportedTaskTypes) == 0 {
supportedTaskTypes = SupportedTaskTypes{"default_supported_task_type"}
Expand All @@ -306,13 +374,18 @@ func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/athena/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,25 +200,25 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo {
}
}

func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, error) {
func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, webapi.SyncPlugin, error) {
sdkCfg, err := awsConfig.GetSdkConfig()
if err != nil {
return Plugin{}, err
return Plugin{}, nil, err
}

return Plugin{
metricScope: metricScope,
client: athena.NewFromConfig(sdkCfg),
cfg: cfg,
awsConfig: sdkCfg,
}, nil
}, nil, nil
}

func init() {
pluginmachinery.PluginRegistry().RegisterRemotePlugin(webapi.PluginEntry{
ID: "athena",
SupportedTaskTypes: []core.TaskType{"hive", "presto"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return NewPlugin(ctx, GetConfig(), aws.GetConfig(), iCtx.MetricsScope())
},
})
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,25 +547,25 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity)
return bigquery.NewService(ctx, options...)
}

func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, error) {
func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, webapi.SyncPlugin, error) {
googleTokenSource, err := google.NewTokenSourceFactory(cfg.GoogleTokenSource)

if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
return nil, nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
}

return &Plugin{
metricScope: metricScope,
cfg: cfg,
googleTokenSource: googleTokenSource,
}, nil
}, nil, nil
}

func newBigQueryJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "bigquery",
SupportedTaskTypes: []core.TaskType{bigqueryQueryJobTask},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
cfg := GetConfig()

return NewPlugin(cfg, iCtx.MetricsScope())
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,12 @@ func newDatabricksJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "databricks",
SupportedTaskTypes: []core.TaskType{"spark"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/snowflake/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ func newSnowflakeJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "snowflake",
SupportedTaskTypes: []core.TaskType{"snowflake"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down

0 comments on commit 09112b8

Please sign in to comment.