From 48d78d98d2ae128323c219be5a93ffc8cc6ff972 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 24 Jan 2024 23:05:56 +0800 Subject: [PATCH] add mockery AsyncAgentClient Signed-off-by: Future-Outlier --- .../go/tasks/plugins/webapi/agent/client.go | 2 + .../tasks/plugins/webapi/agent/client_test.go | 16 +- .../plugins/webapi/agent/integration_test.go | 22 +-- .../agent/mocks/AsyncAgentServiceClient.go | 162 ++++++++++++++++++ .../go/tasks/plugins/webapi/agent/plugin.go | 4 +- .../tasks/plugins/webapi/agent/plugin_test.go | 12 +- 6 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 7277250e60..9cfcb2f4db 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "fmt" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -88,6 +89,7 @@ func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) { agentRegistry := make(map[string]*Agent) cfg := GetConfig() var agentDeployments []*Agent + fmt.Printf("@@@ cfg.AgentForTaskTypes: [%v]\n", cfg.AgentForTaskTypes) // Ensure that the old configuration is backward compatible for taskType, agentID := range cfg.AgentForTaskTypes { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 11c235414c..3d992c6a85 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -2,11 +2,12 @@ package agent import ( "context" + "testing" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" ) func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { @@ -25,8 +26,9 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { return mockMetadataServiceClient } -func getMockServiceClient() *agentMocks.AgentMetadataServiceClient { - mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) +// TODO, USE CREATE, GET DELETE FUNCTION TO MOCK THE OUTPUT +func getMockServiceClient() *agentMocks.AsyncAgentServiceClient { + mockServiceClient := new(agentMocks.AsyncAgentServiceClient) mockRequest := &admin.ListAgentsRequest{} mockResponse := &admin.ListAgentsResponse{ Agents: []*admin.Agent{ @@ -37,8 +39,12 @@ func getMockServiceClient() *agentMocks.AgentMetadataServiceClient { }, } - mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - return mockMetadataServiceClient + mockServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockServiceClient +} + +func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { + return nil } func TestInitializeClientFunc(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 68acd50c9a..29c2db78d2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -108,9 +108,9 @@ func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.Clie return &MockSyncTask{}, nil } -func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return nil, fmt.Errorf("error") -} +// func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +// return nil, fmt.Errorf("error") +// } func TestEndToEnd(t *testing.T) { iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { @@ -172,7 +172,9 @@ func TestEndToEnd(t *testing.T) { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), cs: &ClientSet{ - agentClients: mockGetBadAsyncClientFunc, + agentClients: map[string]service.AsyncAgentServiceClient{ + "localhost:80": mockGetBadAsyncClientFunc(), + }, }, }, }, nil @@ -313,9 +315,9 @@ func newMockAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientFuncSet{ - getAgentClient: mockAsyncTaskClientFunc, - }, + // cs: &ClientSet{ + // getAgentClient: mockAsyncTaskClientFunc, + // }, }, }, nil }, @@ -331,9 +333,9 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientSet{ - agentClients: mockSyncTaskClientFunc, - }, + // cs: &ClientSet{ + // agentClients: mockSyncTaskClientFunc, + // }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go new file mode 100644 index 0000000000..4a2b2c25f3 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go @@ -0,0 +1,162 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" +) + +// AsyncAgentServiceClient is an autogenerated mock type for the AsyncAgentServiceClient type +type AsyncAgentServiceClient struct { + mock.Mock +} + +type AsyncAgentServiceClient_CreateTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_CreateTask) Return(_a0 *admin.CreateTaskResponse, _a1 error) *AsyncAgentServiceClient_CreateTask { + return &AsyncAgentServiceClient_CreateTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnCreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", ctx, in, opts) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnCreateTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", matchers...) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +// CreateTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) CreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) (*admin.CreateTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.CreateTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) *admin.CreateTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.CreateTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_DeleteTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_DeleteTask) Return(_a0 *admin.DeleteTaskResponse, _a1 error) *AsyncAgentServiceClient_DeleteTask { + return &AsyncAgentServiceClient_DeleteTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", ctx, in, opts) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", matchers...) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +// DeleteTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) DeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.DeleteTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) *admin.DeleteTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.DeleteTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTask) Return(_a0 *admin.GetTaskResponse, _a1 error) *AsyncAgentServiceClient_GetTask { + return &AsyncAgentServiceClient_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", ctx, in, opts) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", matchers...) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +// GetTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) (*admin.GetTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) *admin.GetTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index bc2033a70a..b99ee357af 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -225,12 +225,12 @@ func newAgentPlugin() webapi.PluginEntry { cs, err := initializeClients(context.Background()) if err != nil { // We should wait for all agents to be up and running before starting the server - panic(fmt.Sprintf("failed to initalize clients with error: %v", err)) + panic(fmt.Sprintf("failed to initialize clients with error: %v", err)) } agentRegistry, err := initializeAgentRegistry(cs) if err != nil { - panic(fmt.Sprintf("failed to initalize agent registry with error: %v", err)) + panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err)) } cfg := GetConfig() diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 588b20d024..14b0c10c89 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -2,11 +2,12 @@ package agent import ( "context" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" - "sort" "testing" "time" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "golang.org/x/exp/maps" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -21,7 +22,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" ) func TestSyncTask(t *testing.T) { @@ -188,6 +188,7 @@ func TestPlugin(t *testing.T) { func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + agentClients["localhost:80"] = getMockServiceClient() agentMetadataClients["localhost:80"] = getMockMetadataServiceClient() cs := &ClientSet{ @@ -201,9 +202,6 @@ func TestInitializeAgentRegistry(t *testing.T) { agentRegistry, err := initializeAgentRegistry(cs) assert.NoError(t, err) - // In golang, the order of keys in a map is random. So, we sort the keys before asserting. agentRegistryKeys := maps.Keys(agentRegistry) - sort.Strings(agentRegistryKeys) - - assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) + assert.Equal(t, agentRegistryKeys, []string{}) }