Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flytepropeller] Watch agent metadata service dynamically #5460

Merged
merged 41 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
0b1955c
wip
pingsutw Mar 6, 2024
0cffe52
Watch agent service
pingsutw Mar 6, 2024
ac2a89f
lint
pingsutw Mar 6, 2024
b88ecb9
nit
pingsutw Mar 6, 2024
d29998e
Fix test
pingsutw Mar 7, 2024
36e6f87
lint
pingsutw Mar 7, 2024
063846f
nit
pingsutw Mar 7, 2024
7187ade
updateAgentClientSets instead
pingsutw Mar 7, 2024
3dca1a7
lock
pingsutw Mar 8, 2024
60377d9
nit
pingsutw Mar 8, 2024
32a907e
defer
pingsutw Mar 8, 2024
150fc33
resolve conflict
pingsutw Jun 8, 2024
24feb98
lint
pingsutw Jun 8, 2024
d83e51f
Add getter and setter
pingsutw Jun 8, 2024
86c0ccc
refactor(webapi): Improve agent client handling and logging
pingsutw Jun 8, 2024
4e9e7fb
errorf
pingsutw Jun 8, 2024
8e3e805
update
pingsutw Jun 8, 2024
a1d2a2c
remove logger
pingsutw Jun 8, 2024
ac06415
nit
pingsutw Jun 8, 2024
6fb31ab
nit
pingsutw Jun 8, 2024
26731cd
lint
pingsutw Jun 8, 2024
e91de3b
lint
pingsutw Jun 8, 2024
17e6946
Dynamically update DefaultPlugins by agent watcher
Future-Outlier Jun 8, 2024
d854ac1
merge master, solve conflict and add annotations
Future-Outlier Jun 8, 2024
bbe862b
change DefaultPlugins to defaultPlugins, since we only need to keep d…
Future-Outlier Jun 8, 2024
a219bd2
fix lint error, need to fix test later
Future-Outlier Jun 8, 2024
b0e4591
nit
Future-Outlier Jun 8, 2024
fc4288e
use read lock in agent registry getter
Future-Outlier Jun 9, 2024
f6e8916
an update version with race condition, move default plugins to task/c…
Future-Outlier Jun 9, 2024
b97b703
a work versionpwd
Future-Outlier Jun 9, 2024
00fb2b2
kevin's update
pingsutw Jun 10, 2024
30108cc
nit
pingsutw Jun 10, 2024
9710584
lint
pingsutw Jun 10, 2024
6a10b45
lint
pingsutw Jun 10, 2024
bcd1085
revert go.mod and go.sum changes
Future-Outlier Jun 10, 2024
3f92c5b
Update agent/plugin.go mutex usage
Future-Outlier Jun 11, 2024
1256557
solve race condition
Future-Outlier Jun 11, 2024
a475a27
Merge branch 'master' into watch-agent-dynamically
Future-Outlier Jun 22, 2024
8fc782f
Add test for agent watcher
Future-Outlier Jun 22, 2024
0803569
Add Tests for AgentService Interface
Future-Outlier Jun 23, 2024
297d575
Merge branch 'master' into watch-agent-dynamically
Future-Outlier Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/core/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package core
import (
"context"
"fmt"
"sync"

"k8s.io/utils/strings/slices"
)

//go:generate mockery -all -case=underscore
Expand Down Expand Up @@ -55,7 +58,27 @@ type Plugin interface {
Finalize(ctx context.Context, tCtx TaskExecutionContext) error
}

// Loads and validates a plugin.
type AgentService struct {
mu sync.RWMutex
supportedTaskTypes []TaskType
CorePlugin Plugin
}

// ContainTaskType check if agent supports this task type.
func (p *AgentService) ContainTaskType(taskType TaskType) bool {
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
p.mu.RLock()
defer p.mu.RUnlock()
return slices.Contains(p.supportedTaskTypes, taskType)
}

// SetSupportedTaskType set supportTaskType in the agent service.
func (p *AgentService) SetSupportedTaskType(taskTypes []TaskType) {
p.mu.Lock()
defer p.mu.Unlock()
p.supportedTaskTypes = taskTypes
}

// LoadPlugin Loads and validates a plugin.
func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) {
plugin, err := entry.LoadPlugin(ctx, iCtx)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,17 @@ func TestLoadPlugin(t *testing.T) {
})

}

func TestAgentService(t *testing.T) {
agentService := core.AgentService{}
taskTypes := []core.TaskType{"sensor", "chatgpt"}

for _, taskType := range taskTypes {
assert.Equal(t, false, agentService.ContainTaskType(taskType))
}

agentService.SetSupportedTaskType(taskTypes)
for _, taskType := range taskTypes {
assert.Equal(t, true, agentService.ContainTaskType(taskType))
}
}
41 changes: 22 additions & 19 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,11 @@
return context.WithTimeout(ctx, timeout)
}

func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
agentRegistry := make(Registry)
func getAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
newAgentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment

// Ensure that the old configuration is backward compatible
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
agent := Agent{AgentDeployment: cfg.AgentDeployments[agentDeploymentID], IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: &agent}
}

if len(cfg.DefaultAgent.Endpoint) != 0 {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
Expand Down Expand Up @@ -137,27 +131,36 @@
deprecatedSupportedTaskTypes := agent.SupportedTaskTypes
for _, supportedTaskType := range deprecatedSupportedTaskTypes {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
newAgentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}

Check warning on line 134 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L134

Added line #L134 was not covered by tests
}

supportedTaskCategories := agent.SupportedTaskCategories
for _, supportedCategory := range supportedTaskCategories {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
newAgentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := agentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}

Check warning on line 150 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L148-L150

Added lines #L148 - L150 were not covered by tests
}
}
}
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
setAgentRegistry(agentRegistry)

// Ensure that the old configuration is backward compatible
for _, taskType := range cfg.SupportedTaskTypes {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: &cfg.DefaultAgent, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

return newAgentRegistry
}

func getAgentClientSets(ctx context.Context) *ClientSet {
Expand Down
19 changes: 12 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -117,7 +113,7 @@ func TestEndToEnd(t *testing.T) {
t.Run("failed to create a job", func(t *testing.T) {
agentPlugin := newMockAsyncAgentPlugin()
agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientSet{
Expand Down Expand Up @@ -259,6 +255,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {

func newMockAsyncAgentPlugin() webapi.PluginEntry {
asyncAgentClient := new(agentMocks.AsyncAgentServiceClient)
agentRegistry := Registry{
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}

mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}
Expand All @@ -283,20 +282,25 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry {
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: &cfg,
cs: &ClientSet{
asyncAgentClients: map[string]service.AsyncAgentServiceClient{
defaultAgentEndpoint: asyncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
}

func newMockSyncAgentPlugin() webapi.PluginEntry {
agentRegistry := Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
}

syncAgentClient := new(agentMocks.SyncAgentServiceClient)
output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output}
Expand All @@ -323,14 +327,15 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"openai"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: &cfg,
cs: &ClientSet{
syncAgentClients: map[string]service.SyncAgentServiceClient{
defaultAgentEndpoint: syncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
Expand Down
Loading
Loading