Skip to content

Commit

Permalink
[kubectl-plugin] add --worker-gpu flag for cluster creation
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Sy Kim <[email protected]>
  • Loading branch information
andrewsykim committed Dec 20, 2024
1 parent 2c47bbc commit ee4e97c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions kubectl-plugin/pkg/cmd/create/create_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type CreateClusterOptions struct {
workerGrpName string
workerCPU string
workerMemory string
workerGPU string
workerReplicas int32
dryRun bool
}
Expand Down Expand Up @@ -76,6 +77,7 @@ func NewCreateClusterCommand(streams genericclioptions.IOStreams) *cobra.Command
cmd.Flags().Int32Var(&options.workerReplicas, "worker-replicas", 1, "Number of the worker group replicas. Default of 1")
cmd.Flags().StringVar(&options.workerCPU, "worker-cpu", "2", "Number of CPU for the ray worker. Default to 2")
cmd.Flags().StringVar(&options.workerMemory, "worker-memory", "4Gi", "Amount of memory to use for the ray worker. Default to 4Gi")
cmd.Flags().StringVar(&options.workerGPU, "worker-gpu", "0", "Number of GPU for the ray worker. Default to 0")
cmd.Flags().BoolVar(&options.dryRun, "dry-run", false, "Will not apply the generated cluster and will print out the generated yaml")

options.configFlags.AddFlags(cmd.Flags())
Expand Down Expand Up @@ -130,6 +132,7 @@ func (options *CreateClusterOptions) Run(ctx context.Context, factory cmdutil.Fa
WorkerReplicas: options.workerReplicas,
WorkerCPU: options.workerCPU,
WorkerMemory: options.workerMemory,
WorkerGPU: options.workerGPU,
},
}

Expand Down
13 changes: 13 additions & 0 deletions kubectl-plugin/pkg/util/generation/generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type RayClusterSpecObject struct {
HeadMemory string
WorkerGrpName string
WorkerCPU string
WorkerGPU string
WorkerMemory string
HeadLifecyclePrestopExecCommand []string
WorkerLifecyclePrestopExecComand []string
Expand Down Expand Up @@ -98,6 +99,18 @@ func (rayClusterSpecObject *RayClusterSpecObject) generateRayClusterSpec() *rayv
corev1.ResourceMemory: resource.MustParse(rayClusterSpecObject.WorkerMemory),
}))))))

gpuResource := resource.MustParse(rayClusterSpecObject.WorkerGPU)
if !gpuResource.IsZero() {
var requests, limits corev1.ResourceList
requests = *rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests
limits = *rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits
requests[corev1.ResourceName("nvidia.com/gpu")] = gpuResource
limits[corev1.ResourceName("nvidia.com/gpu")] = gpuResource

rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests = &requests
rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits = &limits
}

// Lifecycle cannot be empty: an empty lifecycle will stop pod startup, so a lifecycle is only added when the prestop exec command is not empty.
if len(rayClusterSpecObject.WorkerLifecyclePrestopExecComand) > 0 {
rayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Lifecycle = corev1ac.Lifecycle().
Expand Down
5 changes: 5 additions & 0 deletions kubectl-plugin/pkg/util/generation/generation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
Expand All @@ -23,6 +24,7 @@ func TestGenerateRayCluterApplyConfig(t *testing.T) {
WorkerReplicas: 3,
WorkerCPU: "2",
WorkerMemory: "10Gi",
WorkerGPU: "1",
},
}

Expand All @@ -37,6 +39,7 @@ func TestGenerateRayCluterApplyConfig(t *testing.T) {
assert.Equal(t, testRayClusterYamlObject.WorkerGrpName, *result.Spec.WorkerGroupSpecs[0].GroupName)
assert.Equal(t, testRayClusterYamlObject.WorkerReplicas, *result.Spec.WorkerGroupSpecs[0].Replicas)
assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerCPU), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Cpu())
assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerGPU), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Name(corev1.ResourceName("nvidia.com/gpu"), resource.DecimalSI))
assert.Equal(t, resource.MustParse(testRayClusterYamlObject.WorkerMemory), *result.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests.Memory())
}

Expand All @@ -54,6 +57,7 @@ func TestGenerateRayJobApplyConfig(t *testing.T) {
WorkerReplicas: 3,
WorkerCPU: "2",
WorkerMemory: "10Gi",
WorkerGPU: "0",
},
}

Expand Down Expand Up @@ -85,6 +89,7 @@ func TestConvertRayClusterApplyConfigToYaml(t *testing.T) {
WorkerReplicas: 3,
WorkerCPU: "2",
WorkerMemory: "10Gi",
WorkerGPU: "0",
},
}

Expand Down

0 comments on commit ee4e97c

Please sign in to comment.