diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 3efed684c7..17141983cb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -78,7 +78,7 @@ type Config struct { // or 0.0.0.0 (available from all interfaces). By default, this is localhost. DashboardHost string `json:"dashboardHost,omitempty"` - BatchScheduler schedulerConfig.Config `json:"BatchScheduler,omitempty"` + BatchScheduler schedulerConfig.Config `json:"batchScheduler,omitempty"` // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address. DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go index 5048869eab..a725e93c5a 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags.go @@ -55,6 +55,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.scheduler"), defaultConfig.BatchScheduler.Scheduler, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "batchScheduler.parameters"), defaultConfig.BatchScheduler.Parameters, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go index 05871adc51..960752c0a5 100755 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -169,6 +169,34 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_batchScheduler.scheduler", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.scheduler", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.scheduler"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Scheduler) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_batchScheduler.parameters", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("batchScheduler.parameters", testValue) + if vString, err := cmdFlags.GetString("batchScheduler.parameters"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BatchScheduler.Parameters) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { t.Run("Override", func(t *testing.T) {