From 3b07407a1a50e158f933d33dd41a6543609fea04 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 20 Dec 2023 06:36:35 +0800 Subject: [PATCH] change getFinalAgent implementation and interface Signed-off-by: Future Outlier --- .../go/tasks/plugins/webapi/agent/plugin.go | 24 +++++++------------ .../tasks/plugins/webapi/agent/plugin_test.go | 13 +++++----- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index a67239bf89..d8e5b3651c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -94,10 +94,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR } outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() - agent, err := getFinalAgent(taskTemplate.Type, p.cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to find agent agent with error: %v", err) - } + agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) client, err := p.getClient(ctx, agent, p.connectionCache) if err != nil { @@ -129,7 +126,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, err := getFinalAgent(metadata.TaskType, p.cfg) + agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) if err != nil { return nil, fmt.Errorf("failed to find agent with error: %v", err) } @@ -161,10 +158,8 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, err := getFinalAgent(metadata.TaskType, p.cfg) - if err != nil { - return fmt.Errorf("failed to find agent agent with error: %v", err) - } + agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) + client, err := p.getClient(ctx, agent, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) @@ -223,15 +218,12 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource Res return taskCtx.OutputWriter().Put(ctx, opReader) } -func getFinalAgent(taskType string, cfg *Config) (*Agent, error) { - if id, exists := cfg.AgentForTaskTypes[taskType]; exists { - if agent, exists := cfg.Agents[id]; exists { - return agent, nil - } - return nil, fmt.Errorf("no agent definition found for ID %s that matches task type %s", id, taskType) +func getFinalAgent(taskType string, cfg *Config, agentRegistry map[string]*Agent) *Agent { + if agent, exists := agentRegistry[taskType]; exists { + return agent } - return &cfg.DefaultAgent, nil + return &cfg.DefaultAgent } func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (*grpc.ClientConn, error) { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index bd5ab12a83..c6da80a99c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -57,12 +57,13 @@ func TestPlugin(t *testing.T) { }) t.Run("test getFinalAgent", func(t *testing.T) { - agent, _ := getFinalAgent("spark", &cfg) - assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, agent.Endpoint) - agent, _ = getFinalAgent("foo", &cfg) - assert.Equal(t, cfg.DefaultAgent.Endpoint, agent.Endpoint) - _, err := getFinalAgent("bar", &cfg) - assert.NotNil(t, err) + agentRegistry := map[string]*Agent{"spark": {Endpoint: "localhost:80"}} + agent := getFinalAgent("spark", &cfg, agentRegistry) + assert.Equal(t, agent.Endpoint, "localhost:80") + agent = getFinalAgent("foo", &cfg, agentRegistry) + assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint) + agent = getFinalAgent("bar", &cfg, agentRegistry) + assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint) }) t.Run("test getAgentMetadataClientFunc", func(t *testing.T) {