From ee4e97cd82746a42c90ac2360cfaa831d336c6f5 Mon Sep 17 00:00:00 2001
From: Andrew Sy Kim
Date: Fri, 20 Dec 2024 20:00:59 +0000
Subject: [PATCH] [kubectl-plugin] add --worker-gpu flag for cluster creation

Signed-off-by: Andrew Sy Kim
---
 kubectl-plugin/pkg/cmd/create/create_cluster.go  |  3 +++
 kubectl-plugin/pkg/util/generation/generation.go | 13 +++++++++++++
 .../pkg/util/generation/generation_test.go       |  5 +++++
 3 files changed, 21 insertions(+)

diff --git a/kubectl-plugin/pkg/cmd/create/create_cluster.go b/kubectl-plugin/pkg/cmd/create/create_cluster.go
index 3434e93642e..328920bd877 100644
--- a/kubectl-plugin/pkg/cmd/create/create_cluster.go
+++ b/kubectl-plugin/pkg/cmd/create/create_cluster.go
@@ -25,6 +25,7 @@ type CreateClusterOptions struct {
 	workerGrpName  string
 	workerCPU      string
 	workerMemory   string
+	workerGPU      string
 	workerReplicas int32
 	dryRun         bool
 }
@@ -76,6 +77,7 @@ func NewCreateClusterCommand(streams genericclioptions.IOStreams) *cobra.Command
 	cmd.Flags().Int32Var(&options.workerReplicas, "worker-replicas", 1, "Number of the worker group replicas. Default of 1")
 	cmd.Flags().StringVar(&options.workerCPU, "worker-cpu", "2", "Number of CPU for the ray worker. Default to 2")
 	cmd.Flags().StringVar(&options.workerMemory, "worker-memory", "4Gi", "Amount of memory to use for the ray worker. Default to 4Gi")
+	cmd.Flags().StringVar(&options.workerGPU, "worker-gpu", "0", "Number of GPU for the ray worker. Default to 0")
 	cmd.Flags().BoolVar(&options.dryRun, "dry-run", false, "Will not apply the generated cluster and will print out the generated yaml")
 
 	options.configFlags.AddFlags(cmd.Flags())
@@ -130,6 +132,7 @@ func (options *CreateClusterOptions) Run(ctx context.Context, factory cmdutil.Fa
 			WorkerReplicas: options.workerReplicas,
 			WorkerCPU:      options.workerCPU,
 			WorkerMemory:   options.workerMemory,
+			WorkerGPU:      options.workerGPU,
 		},
 	}
 
diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go
index e74bd7aaf35..aa06dca2ed8 100644
--- a/kubectl-plugin/pkg/util/generation/generation.go
+++ b/kubectl-plugin/pkg/util/generation/generation.go
@@ -19,6 +19,7 @@ type RayClusterSpecObject struct {
 	HeadMemory                       string
 	WorkerGrpName                    string
 	WorkerCPU                        string
+	WorkerGPU                        string
 	WorkerMemory                     string
 	HeadLifecyclePrestopExecCommand  []string
 	WorkerLifecyclePrestopExecComand []string
@@ -98,6 +99,18 @@ func (rayClusterSpecObject *RayClusterSpecObject) generateRayClusterSpec() *rayv
 				corev1.ResourceMemory: resource.MustParse(rayClusterSpecObject.WorkerMemory),
 			}))))))
 
+	gpuResource := resource.MustParse(rayClusterSpecObject.WorkerGPU)
+	if !gpuResource.IsZero() {
+		var requests, limits corev1.ResourceList
+		requests = *rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests
+		limits = *rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits
+		requests[corev1.ResourceName("nvidia.com/gpu")] = gpuResource
+		limits[corev1.ResourceName("nvidia.com/gpu")] = gpuResource
+
+		rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests = &requests
+		rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits = &limits
+	}
+
 	// Lifecycle cannot be empty, an empty lifecycle will stop pod startup so this will add lifecycle if its not empty
 	if len(rayClusterSpecObject.WorkerLifecyclePrestopExecComand) > 0 {
 		rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Lifecycle = corev1ac.Lifecycle().
diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go
index 9ac7709fb44..2acf0d053b2 100644
--- a/kubectl-plugin/pkg/util/generation/generation_test.go
+++ b/kubectl-plugin/pkg/util/generation/generation_test.go
@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/resource"
 
 	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
@@ -23,6 +24,7 @@ func TestGenerateRayCluterApplyConfig(t *testing.T) {
 			WorkerReplicas: 3,
 			WorkerCPU:      "2",
 			WorkerMemory:   "10Gi",
+			WorkerGPU:      "1",
 		},
 	}
 
@@ -37,6 +39,7 @@ func TestGenerateRayCluterApplyConfig(t *testing.T) {
 	assert.Equal(t, testRayClusterYamlObject.WorkerGrpName, *result.Spec.WorkerGroupSpecs[0].GroupName)
 	assert.Equal(t, testRayClusterYamlObject.WorkerReplicas, *result.Spec.WorkerGroupSpecs[0].Replicas)
 	assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerCPU), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Cpu())
+	assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerGPU), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Name(corev1.ResourceName("nvidia.com/gpu"), resource.DecimalSI))
 	assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerMemory), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Memory())
 }
 
@@ -54,6 +57,7 @@ func TestGenerateRayJobApplyConfig(t *testing.T) {
 			WorkerReplicas: 3,
 			WorkerCPU:      "2",
 			WorkerMemory:   "10Gi",
+			WorkerGPU:      "0",
 		},
 	}
 
@@ -85,6 +89,7 @@ func TestConvertRayClusterApplyConfigToYaml(t *testing.T) {
 			WorkerReplicas: 3,
 			WorkerCPU:      "2",
 			WorkerMemory:   "10Gi",
+			WorkerGPU:      "0",
 		},
 	}
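
Example usage (a sketch, not part of the patch; assumes the plugin is
invoked as `kubectl ray` and a cluster named `sample-cluster`):

    kubectl ray create cluster sample-cluster \
        --worker-replicas 2 \
        --worker-cpu 2 \
        --worker-memory 4Gi \
        --worker-gpu 1

The flag defaults to "0", and generateRayClusterSpec skips the GPU
resource when the parsed quantity is zero, so invocations without
--worker-gpu generate the same spec as before.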
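With --worker-gpu 1, the worker container in the generated RayCluster
should carry nvidia.com/gpu on both requests and limits, roughly as
below (a sketch; entries already produced by the generation code are
carried over unchanged):

    resources:
      requests:
        cpu: "2"             # from --worker-cpu
        memory: 4Gi          # from --worker-memory
        nvidia.com/gpu: "1"  # added by this patch
      limits:
        nvidia.com/gpu: "1"  # added by this patch
        # pre-existing limit entries, if any, are preserved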