diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/client.go b/flyteplugins/go/tasks/pluginmachinery/k8s/client.go index f14ae2c8a0..0ab46081e9 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/client.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/client.go @@ -69,7 +69,7 @@ func NewKubeClient(config *rest.Config, options Options) (core.KubeClient, error if options.ClientOptions == nil { options.ClientOptions = &client.Options{ HTTPClient: httpClient, - Mapper: mapper, + Mapper: mapper, } } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go new file mode 100644 index 0000000000..139df552b7 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -0,0 +1,43 @@ +package agent + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" +) + +type GetAgentClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) + +// Clientset contains the clients exposed to communicate with various agent services. +type ClientFuncSet struct { + getAgentClient GetAgentClientFunc + getAgentMetadataClient GetAgentMetadataClientFunc +} + +func getAgentClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { + conn, err := getGrpcConnection(ctx, agent, connectionCache) + if err != nil { + return nil, err + } + + return service.NewAsyncAgentServiceClient(conn), nil +} + +func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { + conn, err := getGrpcConnection(ctx, agent, connectionCache) + if err != nil { + return nil, err + } + + return service.NewAgentMetadataServiceClient(conn), nil +} + +func initializeClientFunc() *ClientFuncSet { + return &ClientFuncSet{ + getAgentClient: getAgentClientFunc, + getAgentMetadataClient: getAgentMetadataClientFunc, + } +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go new file mode 100644 index 0000000000..5dfbe2f521 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -0,0 +1,14 @@ +package agent + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInitializeClientFunc(t *testing.T) { + cs := initializeClientFunc() + assert.NotNil(t, cs) + assert.NotNil(t, cs.getAgentClient) + assert.NotNil(t, cs.getAgentMetadataClient) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 827af0d907..573145d29e 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -171,7 +171,9 @@ func TestEndToEnd(t *testing.T) { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockGetBadAsyncClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockGetBadAsyncClientFunc, + }, }, }, nil } @@ -311,7 +313,9 @@ func newMockAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockAsyncTaskClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockAsyncTaskClientFunc, + }, }, }, nil }, @@ -327,7 +331,9 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockSyncTaskClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockSyncTaskClientFunc, + }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index b20bd62d7a..9badd074eb 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -17,7 +17,6 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" @@ -30,13 +29,10 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) -type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) -type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) - type Plugin struct { metricScope promutils.Scope cfg *Config - getClient GetClientFunc + cs *ClientFuncSet connectionCache map[*Agent]*grpc.ClientConn agentRegistry map[string]*Agent // map[taskType] => Agent } @@ -96,7 +92,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -145,7 +141,7 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba return nil, fmt.Errorf("failed to find agent with error: %v", err) } - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -174,7 +170,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -287,24 +283,6 @@ func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*A return conn, nil } -func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAsyncAgentServiceClient(conn), nil -} - -func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAgentMetadataServiceClient(conn), nil -} - func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() return admin.TaskExecutionMetadata{ @@ -334,7 +312,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte return context.WithTimeout(ctx, timeout) } -func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, getAgentMetadataClientFunc GetAgentMetadataClientFunc) (map[string]*Agent, error) { +func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, cs *ClientFuncSet) (map[string]*Agent, error) { agentRegistry := make(map[string]*Agent) var agentDeployments []*Agent @@ -348,7 +326,7 @@ func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.Clien } agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) for _, agentDeployment := range agentDeployments { - client, err := getAgentMetadataClientFunc(context.Background(), agentDeployment, connectionCache) + client, err := cs.getAgentMetadataClient(context.Background(), agentDeployment, connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent [%v] with error: [%v]", agentDeployment, err) } @@ -385,9 +363,10 @@ func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.Clien } func newAgentPlugin() webapi.PluginEntry { + cs := initializeClientFunc() cfg := GetConfig() connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, getAgentMetadataClientFunc) + agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, cs) if err != nil { // We should wait for all agents to be up and running before starting the server panic(err) @@ -403,7 +382,7 @@ func newAgentPlugin() webapi.PluginEntry { return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: cfg, - getClient: getClientFunc, + cs: cs, connectionCache: connectionCache, agentRegistry: agentRegistry, }, nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index b9fd7e1b35..a90f265ae8 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -108,13 +108,15 @@ func TestPlugin(t *testing.T) { }) t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) t.Run("test getClientFunc more config", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) @@ -123,12 +125,13 @@ func TestPlugin(t *testing.T) { connectionCache := make(map[*Agent]*grpc.ClientConn) agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} - client, err := getClientFunc(context.Background(), agent, connectionCache) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, client) assert.NotNil(t, client, connectionCache[agent]) - cachedClient, err := getClientFunc(context.Background(), agent, connectionCache) + cachedClient, err := cs.getAgentClient(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, cachedClient) assert.Equal(t, client, cachedClient) @@ -238,11 +241,14 @@ func TestInitializeAgentRegistry(t *testing.T) { return mockClient, nil } + cs := initializeClientFunc() + cs.getAgentMetadataClient = getAgentMetadataClientFunc + cfg := defaultConfig cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, getAgentMetadataClientFunc) + agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, cs) assert.NoError(t, err) // In golang, the order of keys in a map is random. So, we sort the keys before asserting.