Skip to content

Commit

Permalink
change getFinalAgent implementation and interface
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
Future Outlier committed Dec 19, 2023
1 parent 3a4ae5c commit 3b07407
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 22 deletions.
24 changes: 8 additions & 16 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
}
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to find agent agent with error: %v", err)
}
agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry)

client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
Expand Down Expand Up @@ -129,7 +126,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

agent, err := getFinalAgent(metadata.TaskType, p.cfg)
agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry)
if err != nil {
return nil, fmt.Errorf("failed to find agent with error: %v", err)
}
Expand Down Expand Up @@ -161,10 +158,8 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

agent, err := getFinalAgent(metadata.TaskType, p.cfg)
if err != nil {
return fmt.Errorf("failed to find agent agent with error: %v", err)
}
agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry)

Check warning on line 162 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L161-L162

Added lines #L161 - L162 were not covered by tests
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect to agent with error: %v", err)
Expand Down Expand Up @@ -223,15 +218,12 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource Res
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if agent, exists := cfg.Agents[id]; exists {
return agent, nil
}
return nil, fmt.Errorf("no agent definition found for ID %s that matches task type %s", id, taskType)
func getFinalAgent(taskType string, cfg *Config, agentRegistry map[string]*Agent) *Agent {
if agent, exists := agentRegistry[taskType]; exists {
return agent
}

return &cfg.DefaultAgent, nil
return &cfg.DefaultAgent
}

func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (*grpc.ClientConn, error) {
Expand Down
13 changes: 7 additions & 6 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ func TestPlugin(t *testing.T) {
})

t.Run("test getFinalAgent", func(t *testing.T) {
agent, _ := getFinalAgent("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, agent.Endpoint)
agent, _ = getFinalAgent("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, agent.Endpoint)
_, err := getFinalAgent("bar", &cfg)
assert.NotNil(t, err)
agentRegistry := map[string]*Agent{"spark": {Endpoint: "localhost:80"}}
agent := getFinalAgent("spark", &cfg, agentRegistry)
assert.Equal(t, agent.Endpoint, "localhost:80")
agent = getFinalAgent("foo", &cfg, agentRegistry)
assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint)
agent = getFinalAgent("bar", &cfg, agentRegistry)
assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint)
})

t.Run("test getAgentMetadataClientFunc", func(t *testing.T) {
Expand Down

0 comments on commit 3b07407

Please sign in to comment.