diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go index ff436c3b05c..70dccd7d88f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/config.go @@ -1,21 +1,21 @@ package batchscheduler -type BatchSchedulerConfig struct { +type Config struct { Scheduler string `json:"scheduler,omitempty"` Parameters string `json:"parameters,omitempty"` } -func NewDefaultBatchSchedulerConfig() BatchSchedulerConfig { - return BatchSchedulerConfig{ +func NewConfig() Config { + return Config{ Scheduler: "", Parameters: "", } } -func (b *BatchSchedulerConfig) GetScheduler() string { +func (b *Config) GetScheduler() string { return b.Scheduler } -func (b *BatchSchedulerConfig) GetParameters() string { +func (b *Config) GetParameters() string { return b.Parameters } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go index 02f0d4658af..dc17d1155f0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/default.go @@ -19,7 +19,7 @@ func NewDefaultPlugin() *DefaultPlugin { func (d *DefaultPlugin) GetSchedulerName() string { return DefaultScheduler } func (d *DefaultPlugin) ParseJob( - config *BatchSchedulerConfig, + config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go index edf51b0eb73..5aa2ff2ff9c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins.go @@ -8,13 +8,13 @@ import ( type SchedulerPlugin interface { GetSchedulerName() string - ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error + ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error ProcessHead(metadata *metav1.ObjectMeta, head *v1.PodSpec) ProcessWorker(metadata *metav1.ObjectMeta, worker *v1.PodSpec, index int) AfterProcess(metadata *metav1.ObjectMeta) } -func NewSchedulerPlugin(config *BatchSchedulerConfig) SchedulerPlugin { +func NewSchedulerPlugin(config *Config) SchedulerPlugin { switch config.GetScheduler() { case Yunikorn: return NewYunikornPlugin() diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go index 28142e5334b..ac3ee9db9cd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/plugins_test.go @@ -8,12 +8,12 @@ import ( func TestCreateSchedulerPlugin(t *testing.T) { var tests = []struct{ - input *BatchSchedulerConfig + input *Config expect string }{ - {input: &BatchSchedulerConfig{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, - {input: &BatchSchedulerConfig{Scheduler: Yunikorn}, expect: Yunikorn}, - {input: &BatchSchedulerConfig{Scheduler:"Unknown"}, expect: DefaultScheduler}, + {input: &Config{Scheduler: DefaultScheduler}, expect: DefaultScheduler}, + {input: &Config{Scheduler: Yunikorn}, expect: Yunikorn}, + {input: &Config{Scheduler:"Unknown"}, expect: DefaultScheduler}, } for _, tt := range tests { t.Run("New scheduler plugin", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go index 35717429fe1..c9f93b10be3 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/batchscheduler/yunikorn.go @@ -31,7 +31,7 @@ func NewYunikornPlugin() *YunikornGangSchedulingConfig { func (s *YunikornGangSchedulingConfig) GetSchedulerName() string { return Yunikorn } -func (s *YunikornGangSchedulingConfig) ParseJob(config *BatchSchedulerConfig, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { +func (s *YunikornGangSchedulingConfig) ParseJob(config *Config, metadata *metav1.ObjectMeta, workerGroupsSpec []*plugins.WorkerGroupSpec, pod *v1.PodSpec, primaryContainerIdx int) error { s.Annotations = nil s.Parameters = config.GetParameters() return s.BuildGangInfo(metadata, workerGroupsSpec, pod, primaryContainerIdx) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 3d96e66ec7a..ce547a680f9 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -24,7 +24,7 @@ var ( DashboardHost: "0.0.0.0", EnableUsageStats: false, ServiceAccount: "default", - BatchScheduler: batchscheduler.NewDefaultBatchSchedulerConfig(), + BatchScheduler: batchscheduler.NewConfig(), Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler batchscheduler.BatchSchedulerConfig `json:"BatchScheduler,omitempty"` + BatchScheduler batchscheduler.Config `json:"BatchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"`