Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flytepropeller] Watch agent metadata service dynamically #5460

Merged
merged 41 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
0b1955c
wip
pingsutw Mar 6, 2024
0cffe52
Watch agent service
pingsutw Mar 6, 2024
ac2a89f
lint
pingsutw Mar 6, 2024
b88ecb9
nit
pingsutw Mar 6, 2024
d29998e
Fix test
pingsutw Mar 7, 2024
36e6f87
lint
pingsutw Mar 7, 2024
063846f
nit
pingsutw Mar 7, 2024
7187ade
updateAgentClientSets instead
pingsutw Mar 7, 2024
3dca1a7
lock
pingsutw Mar 8, 2024
60377d9
nit
pingsutw Mar 8, 2024
32a907e
defer
pingsutw Mar 8, 2024
150fc33
resolve conflict
pingsutw Jun 8, 2024
24feb98
lint
pingsutw Jun 8, 2024
d83e51f
Add getter and setter
pingsutw Jun 8, 2024
86c0ccc
refactor(webapi): Improve agent client handling and logging
pingsutw Jun 8, 2024
4e9e7fb
errorf
pingsutw Jun 8, 2024
8e3e805
update
pingsutw Jun 8, 2024
a1d2a2c
remove logger
pingsutw Jun 8, 2024
ac06415
nit
pingsutw Jun 8, 2024
6fb31ab
nit
pingsutw Jun 8, 2024
26731cd
lint
pingsutw Jun 8, 2024
e91de3b
lint
pingsutw Jun 8, 2024
17e6946
Dynamically update DefaultPlugins by agent watcher
Future-Outlier Jun 8, 2024
d854ac1
merge master, solve conflict and add annotations
Future-Outlier Jun 8, 2024
bbe862b
change DefaultPlugins to defaultPlugins, since we only need to keep d…
Future-Outlier Jun 8, 2024
a219bd2
fix lint error, need to fix test later
Future-Outlier Jun 8, 2024
b0e4591
nit
Future-Outlier Jun 8, 2024
fc4288e
use read lock in agent registry getter
Future-Outlier Jun 9, 2024
f6e8916
an update version with race condition, move default plugins to task/c…
Future-Outlier Jun 9, 2024
b97b703
a work versionpwd
Future-Outlier Jun 9, 2024
00fb2b2
kevin's update
pingsutw Jun 10, 2024
30108cc
nit
pingsutw Jun 10, 2024
9710584
lint
pingsutw Jun 10, 2024
6a10b45
lint
pingsutw Jun 10, 2024
bcd1085
revert go.mod and go.sum changes
Future-Outlier Jun 10, 2024
3f92c5b
Update agent/plugin.go mutex usage
Future-Outlier Jun 11, 2024
1256557
solve race condition
Future-Outlier Jun 11, 2024
a475a27
Merge branch 'master' into watch-agent-dynamically
Future-Outlier Jun 22, 2024
8fc782f
Add test for agent watcher
Future-Outlier Jun 22, 2024
0803569
Add Tests for AgentService Interface
Future-Outlier Jun 23, 2024
297d575
Merge branch 'master' into watch-agent-dynamically
Future-Outlier Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/core/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package core
import (
"context"
"fmt"
"k8s.io/utils/strings/slices"
"sync"
)

//go:generate mockery -all -case=underscore
Expand Down Expand Up @@ -55,7 +57,27 @@ type Plugin interface {
Finalize(ctx context.Context, tCtx TaskExecutionContext) error
}

// Loads and validates a plugin.
type AgentService struct {
mu sync.RWMutex
supportedTaskTypes []TaskType
CorePlugin Plugin
}

// ContainTaskType check if agent supports this task type.
func (p *AgentService) ContainTaskType(taskType TaskType) bool {
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
p.mu.RLock()
defer p.mu.RUnlock()
return slices.Contains(p.supportedTaskTypes, taskType)
}

// SetSupportedTaskType set supportTaskType in the agent service.
func (p *AgentService) SetSupportedTaskType(taskTypes []TaskType) {
p.mu.Lock()
defer p.mu.Unlock()
p.supportedTaskTypes = taskTypes
}

// LoadPlugin Loads and validates a plugin.
func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) {
plugin, err := entry.LoadPlugin(ctx, iCtx)
if err != nil {
Expand Down
41 changes: 22 additions & 19 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,11 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) (
return context.WithTimeout(ctx, timeout)
}

func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
agentRegistry := make(Registry)
func getUpdatedAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
newAgentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment

// Ensure that the old configuration is backward compatible
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
agent := Agent{AgentDeployment: cfg.AgentDeployments[agentDeploymentID], IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: &agent}
}

if len(cfg.DefaultAgent.Endpoint) != 0 {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
Expand Down Expand Up @@ -144,27 +138,36 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
deprecatedSupportedTaskTypes := agent.SupportedTaskTypes
for _, supportedTaskType := range deprecatedSupportedTaskTypes {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
newAgentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}

supportedTaskCategories := agent.SupportedTaskCategories
for _, supportedCategory := range supportedTaskCategories {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
newAgentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := agentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}
}
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
setAgentRegistry(agentRegistry)

// Ensure that the old configuration is backward compatible
for _, taskType := range cfg.SupportedTaskTypes {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: &cfg.DefaultAgent, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

return newAgentRegistry
}

func getAgentClientSets(ctx context.Context) *ClientSet {
Expand Down
13 changes: 9 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -259,6 +255,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {

func newMockAsyncAgentPlugin() webapi.PluginEntry {
asyncAgentClient := new(agentMocks.AsyncAgentServiceClient)
agentRegistry := Registry{
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}

mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}
Expand Down Expand Up @@ -291,12 +290,17 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: asyncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
}

func newMockSyncAgentPlugin() webapi.PluginEntry {
agentRegistry := Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
}

syncAgentClient := new(agentMocks.SyncAgentServiceClient)
output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output}
Expand Down Expand Up @@ -331,6 +335,7 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: syncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
Expand Down
95 changes: 49 additions & 46 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,16 @@ import (
"github.com/flyteorg/flyte/flytestdlib/promutils"
)

type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent

var (
agentRegistry Registry
mu sync.RWMutex
)
const ID = "agent-service"

func getAgentRegistry() Registry {
mu.Lock()
defer mu.Unlock()
return agentRegistry
}

func setAgentRegistry(r Registry) {
mu.Lock()
defer mu.Unlock()
agentRegistry = r
}
type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
registry Registry
mu sync.RWMutex
}

type ResourceWrapper struct {
Expand All @@ -69,18 +56,30 @@ type ResourceMetaWrapper struct {
TaskCategory admin.TaskCategory
}

func (p Plugin) GetConfig() webapi.PluginConfig {
func (p *Plugin) getRegistryTaskTypes() []core.TaskType {
p.mu.RLock()
defer p.mu.RUnlock()
return maps.Keys(p.registry)
}

func (p *Plugin) setRegistry(r Registry) {
p.mu.Lock()
defer p.mu.Unlock()
p.registry = r
}

func (p *Plugin) GetConfig() webapi.PluginConfig {
return GetConfig().WebAPI
}

func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) (
func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) (
namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) {

// Resource requirements are assumed to be the same.
return "default", p.cfg.ResourceConstraints, nil
}

func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
webapi.Resource, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand Down Expand Up @@ -113,7 +112,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion}
agent, isSync := getFinalAgent(&taskCategory, p.cfg)
agent, isSync := p.getFinalAgent(&taskCategory, p.cfg)

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

Expand Down Expand Up @@ -149,7 +148,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
}, nil, nil
}

func (p Plugin) ExecuteTaskSync(
func (p *Plugin) ExecuteTaskSync(
ctx context.Context,
client service.SyncAgentServiceClient,
header *admin.CreateRequestHeader,
Expand Down Expand Up @@ -209,9 +208,9 @@ func (p Plugin) ExecuteTaskSync(
}, err
}

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
func (p *Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)
agent, _ := p.getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -239,12 +238,12 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba
}, nil
}

func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error {
func (p *Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error {
if taskCtx.ResourceMeta() == nil {
return nil
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)
agent, _ := p.getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand All @@ -262,7 +261,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return err
}

func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
func (p *Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
resource := taskCtx.Resource().(ResourceWrapper)
taskInfo := &core.TaskInfo{Logs: resource.LogLinks}

Expand Down Expand Up @@ -314,7 +313,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution state [%v].", resource.State)
}

func (p Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (service.SyncAgentServiceClient, error) {
func (p *Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (service.SyncAgentServiceClient, error) {
client, ok := p.cs.syncAgentClients[agent.Endpoint]
if !ok {
conn, err := getGrpcConnection(ctx, agent)
Expand All @@ -327,7 +326,7 @@ func (p Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (serv
return client, nil
}

func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (service.AsyncAgentServiceClient, error) {
func (p *Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (service.AsyncAgentServiceClient, error) {
client, ok := p.cs.asyncAgentClients[agent.Endpoint]
if !ok {
conn, err := getGrpcConnection(ctx, agent)
Expand All @@ -340,13 +339,25 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser
return client, nil
}

func (p Plugin) watchAgents(ctx context.Context) {
func (p *Plugin) watchAgents(ctx context.Context, agentService *core.AgentService) {
go wait.Until(func() {
clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
agentRegistry := getUpdatedAgentRegistry(ctx, clientSet)
p.setRegistry(agentRegistry)
agentService.SetSupportedTaskType(maps.Keys(agentRegistry))
}, p.cfg.PollInterval.Duration, ctx.Done())
}

func (p *Plugin) getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) {
p.mu.RLock()
defer p.mu.RUnlock()

if agent, exists := p.registry[taskCategory.Name][taskCategory.Version]; exists {
return agent.AgentDeployment, agent.IsSync
}
return &cfg.DefaultAgent, false
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand All @@ -369,14 +380,6 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) {
r := getAgentRegistry()
if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists {
return agent.AgentDeployment, agent.IsSync
}
return &cfg.DefaultAgent, false
}

func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata {
taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID()

Expand All @@ -391,13 +394,12 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata
}
}

func newAgentPlugin() webapi.PluginEntry {
func newAgentPlugin(agentService *core.AgentService) webapi.PluginEntry {
ctx := context.Background()
cfg := GetConfig()

clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...)
agentRegistry := getUpdatedAgentRegistry(ctx, clientSet)
supportedTaskTypes := maps.Keys(agentRegistry)

return webapi.PluginEntry{
ID: "agent-service",
Expand All @@ -407,15 +409,16 @@ func newAgentPlugin() webapi.PluginEntry {
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: clientSet,
registry: agentRegistry,
}
plugin.watchAgents(ctx)
plugin.watchAgents(ctx, agentService)
return plugin, nil
},
}
}

func RegisterAgentPlugin() {
func RegisterAgentPlugin(agentService *core.AgentService) {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})
pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin(agentService))
}
Loading
Loading