diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config.go b/flyteplugins/go/tasks/plugins/webapi/agent/config.go index a0014e9627..1045a7cf80 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config.go @@ -44,6 +44,7 @@ var ( Insecure: true, DefaultTimeout: config.Duration{Duration: 10 * time.Second}, }, + SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, } configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig) @@ -65,6 +66,9 @@ type Config struct { // Maps task types to their agents. {TaskType: AgentId} AgentForTaskTypes map[string]string `json:"agentForTaskTypes" pflag:"-,"` + + // SupportedTaskTypes is a list of task types that are supported by this plugin. + SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` } type Agent struct { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index ce1c6b6841..0710bb3c28 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -161,7 +161,7 @@ func TestEndToEnd(t *testing.T) { tr.OnRead(context.Background()).Return(nil, fmt.Errorf("read fail")) tCtx.OnTaskReader().Return(tr) - agentPlugin := newAgentPlugin(SupportedTaskTypes{}) + agentPlugin := newAgentPlugin() pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin) plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test3")) assert.NoError(t, err) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 9713ba90f3..483e6ed97c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -29,8 +29,6 @@ import ( type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) -type TaskType = string -type SupportedTaskTypes []TaskType type Plugin struct { metricScope promutils.Scope cfg *Config @@ -298,10 +296,8 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte return context.WithTimeout(ctx, timeout) } -func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry { - if len(supportedTaskTypes) == 0 { - supportedTaskTypes = SupportedTaskTypes{"default_supported_task_type"} - } +func newAgentPlugin() webapi.PluginEntry { + supportedTaskTypes := GetConfig().SupportedTaskTypes return webapi.PluginEntry{ ID: "agent-service", @@ -317,9 +313,9 @@ func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry { } } -func RegisterAgentPlugin(supportedTaskTypes SupportedTaskTypes) { +func RegisterAgentPlugin() { gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin(supportedTaskTypes)) + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin()) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 24e93bf1c8..7a4ea350b6 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -43,7 +43,7 @@ func TestPlugin(t *testing.T) { }) t.Run("test newAgentPlugin", func(t *testing.T) { - p := newAgentPlugin(SupportedTaskTypes{}) + p := newAgentPlugin() assert.NotNil(t, p) assert.Equal(t, "agent-service", p.ID) assert.NotNil(t, p.PluginLoader) diff --git a/flytepropeller/pkg/controller/nodes/task/plugin_config.go b/flytepropeller/pkg/controller/nodes/task/plugin_config.go index a3fc5a4cfa..11b4bc6790 100644 --- a/flytepropeller/pkg/controller/nodes/task/plugin_config.go +++ b/flytepropeller/pkg/controller/nodes/task/plugin_config.go @@ -16,8 +16,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) -const AgentServiceKey = "agent-service" - var once sync.Once func WranglePluginsAndGenerateFinalList(ctx context.Context, cfg *config.TaskPluginConfig, pr PluginRegistryIface, @@ -27,8 +25,8 @@ func WranglePluginsAndGenerateFinalList(ctx context.Context, cfg *config.TaskPlu } // Register the GRPC plugin after the config is loaded + once.Do(func() { agent.RegisterAgentPlugin() }) pluginsConfigMeta, err := cfg.GetEnabledPlugins() - once.Do(func() { agent.RegisterAgentPlugin(pluginsConfigMeta.AllDefaultForTaskTypes[AgentServiceKey]) }) if err != nil { return nil, nil, err diff --git a/rsts/deployment/agents/index.rst b/rsts/deployment/agents/index.rst index b9f5e2450b..937e1d9f65 100644 --- a/rsts/deployment/agents/index.rst +++ b/rsts/deployment/agents/index.rst @@ -45,3 +45,4 @@ Discover the process of setting up Agents for Flyte. bigquery mmcloud databricks +=======