Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add mockery AsyncAgentClient
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
Future-Outlier committed Jan 24, 2024
1 parent d3b9220 commit 48d78d9
Showing 6 changed files with 194 additions and 24 deletions.
2 changes: 2 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ import (
"context"
"crypto/x509"
"fmt"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
@@ -88,6 +89,7 @@ func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) {
agentRegistry := make(map[string]*Agent)
cfg := GetConfig()
var agentDeployments []*Agent
fmt.Printf("@@@ cfg.AgentForTaskTypes: [%v]\n", cfg.AgentForTaskTypes)

// Ensure that the old configuration is backward compatible
for taskType, agentID := range cfg.AgentForTaskTypes {
16 changes: 11 additions & 5 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@ package agent

import (
"context"
"testing"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)

func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient {
@@ -25,8 +26,9 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient {
return mockMetadataServiceClient
}

func getMockServiceClient() *agentMocks.AgentMetadataServiceClient {
mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient)
// TODO, USE CREATE, GET DELETE FUNCTION TO MOCK THE OUTPUT
func getMockServiceClient() *agentMocks.AsyncAgentServiceClient {
mockServiceClient := new(agentMocks.AsyncAgentServiceClient)
mockRequest := &admin.ListAgentsRequest{}
mockResponse := &admin.ListAgentsResponse{
Agents: []*admin.Agent{
@@ -37,8 +39,12 @@ func getMockServiceClient() *agentMocks.AgentMetadataServiceClient {
},
}

mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil)
return mockMetadataServiceClient
mockServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil)
return mockServiceClient
}

func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient {
return nil
}

func TestInitializeClientFunc(t *testing.T) {
22 changes: 12 additions & 10 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
@@ -108,9 +108,9 @@ func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.Clie
return &MockSyncTask{}, nil
}

func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return nil, fmt.Errorf("error")
}
// func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
// return nil, fmt.Errorf("error")
// }

func TestEndToEnd(t *testing.T) {
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
@@ -172,7 +172,9 @@ func TestEndToEnd(t *testing.T) {
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientSet{
agentClients: mockGetBadAsyncClientFunc,
agentClients: map[string]service.AsyncAgentServiceClient{
"localhost:80": mockGetBadAsyncClientFunc(),
},
},
},
}, nil
@@ -313,9 +315,9 @@ func newMockAgentPlugin() webapi.PluginEntry {
Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientFuncSet{
getAgentClient: mockAsyncTaskClientFunc,
},
// cs: &ClientSet{
// getAgentClient: mockAsyncTaskClientFunc,
// },
},
}, nil
},
@@ -331,9 +333,9 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientSet{
agentClients: mockSyncTaskClientFunc,
},
// cs: &ClientSet{
// agentClients: mockSyncTaskClientFunc,
// },
},
}, nil
},

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
@@ -225,12 +225,12 @@ func newAgentPlugin() webapi.PluginEntry {
cs, err := initializeClients(context.Background())
if err != nil {
// We should wait for all agents to be up and running before starting the server
panic(fmt.Sprintf("failed to initalize clients with error: %v", err))
panic(fmt.Sprintf("failed to initialize clients with error: %v", err))
}

agentRegistry, err := initializeAgentRegistry(cs)
if err != nil {
panic(fmt.Sprintf("failed to initalize agent registry with error: %v", err))
panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err))
}

cfg := GetConfig()
12 changes: 5 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@ package agent

import (
"context"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"sort"
"testing"
"time"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"golang.org/x/exp/maps"

"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
@@ -21,7 +22,6 @@ import (
"github.com/flyteorg/flyte/flytestdlib/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/exp/maps"
)

func TestSyncTask(t *testing.T) {
@@ -188,6 +188,7 @@ func TestPlugin(t *testing.T) {
func TestInitializeAgentRegistry(t *testing.T) {
agentClients := make(map[string]service.AsyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)
agentClients["localhost:80"] = getMockServiceClient()
agentMetadataClients["localhost:80"] = getMockMetadataServiceClient()

cs := &ClientSet{
@@ -201,9 +202,6 @@ func TestInitializeAgentRegistry(t *testing.T) {
agentRegistry, err := initializeAgentRegistry(cs)
assert.NoError(t, err)

// In golang, the order of keys in a map is random. So, we sort the keys before asserting.
agentRegistryKeys := maps.Keys(agentRegistry)
sort.Strings(agentRegistryKeys)

assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"})
assert.Equal(t, agentRegistryKeys, []string{})
}

0 comments on commit 48d78d9

Please sign in to comment.