Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support extended resources for Ray pods #2436

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apiserver/pkg/model/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ func FromKubeToAPIComputeTemplate(configMap *corev1.ConfigMap) *api.ComputeTempl
cpu, _ := strconv.ParseUint(configMap.Data["cpu"], 10, 32)
memory, _ := strconv.ParseUint(configMap.Data["memory"], 10, 32)
gpu, _ := strconv.ParseUint(configMap.Data["gpu"], 10, 32)
efa, _ := strconv.ParseUint(configMap.Data["efa"], 10, 32)

runtime := &api.ComputeTemplate{}
runtime.Name = configMap.Name
Expand All @@ -385,6 +386,7 @@ func FromKubeToAPIComputeTemplate(configMap *corev1.ConfigMap) *api.ComputeTempl
runtime.Memory = uint32(memory)
runtime.Gpu = uint32(gpu)
runtime.GpuAccelerator = configMap.Data["gpu_accelerator"]
runtime.Efa = uint32(efa)
val, ok := configMap.Data["tolerations"]
if ok {
err := json.Unmarshal([]byte(val), &runtime.Tolerations)
Expand Down
7 changes: 7 additions & 0 deletions apiserver/pkg/model/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ var configMapWithoutTolerations = corev1.ConfigMap{
"gpu": "0",
"gpu_accelerator": "",
"memory": "8",
"efa": "0",
"name": "head-node-template",
"namespace": "max",
},
Expand All @@ -141,6 +142,7 @@ var configMapWithTolerations = corev1.ConfigMap{
"gpu": "0",
"gpu_accelerator": "",
"memory": "8",
"efa": "0",
"name": "head-node-template",
"namespace": "max",
"tolerations": "[{\"key\":\"blah1\",\"operator\":\"Exists\",\"effect\":\"NoExecute\"}]",
Expand Down Expand Up @@ -578,6 +580,11 @@ func TestPopulateTemplate(t *testing.T) {
t.Errorf("failed to convert config map, got %v, expected %v", tolerationToString(template.Tolerations[0]),
tolerationToString(&expectedTolerations))
}

assert.Equal(t, uint32(4), template.Cpu, "CPU mismatch")
assert.Equal(t, uint32(8), template.Memory, "Memory mismatch")
assert.Equal(t, uint32(0), template.Gpu, "GPU mismatch")
assert.Equal(t, uint32(0), template.Efa, "EFA mismatch")
}

func tolerationToString(toleration *api.PodToleration) string {
Expand Down
31 changes: 22 additions & 9 deletions apiserver/pkg/util/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,15 @@ func buildNodeGroupAnnotations(computeTemplate *api.ComputeTemplate, image strin
return annotations
}

// Add resource to container
func addResourceToContainer(container *corev1.Container, resourceName string, quantity uint32) {
if quantity > 0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

if quantity == 0 {
  return
}

quantityStr := fmt.Sprint(quantity)
container.Resources.Requests[corev1.ResourceName(resourceName)] = resource.MustParse(quantityStr)
container.Resources.Limits[corev1.ResourceName(resourceName)] = resource.MustParse(quantityStr)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

quantityStr := fmt.Sprint(quantity)
container.Resources.Requests[corev1.ResourceName(resourceName)] = resource.MustParse(quantityStr)
container.Resources.Limits[corev1.ResourceName(resourceName)] = resource.MustParse(quantityStr)
}
}

// Build head node template
func buildHeadPodTemplate(imageVersion string, envs *api.EnvironmentVariables, spec *api.HeadGroupSpec, computeRuntime *api.ComputeTemplate, enableServeService bool) (*corev1.PodTemplateSpec, error) {
image := constructRayImage(RayClusterDefaultImageRepository, imageVersion)
Expand Down Expand Up @@ -232,15 +241,18 @@ func buildHeadPodTemplate(imageVersion string, envs *api.EnvironmentVariables, s
// We are filtering container by name `ray-head`. If container with this name does not exist
// (should never happen) we are not adding container specific parameters
if container, index, ok := GetContainerByName(podTemplateSpec.Spec.Containers, "ray-head"); ok {
if computeRuntime.GetGpu() != 0 {
gpu := computeRuntime.GetGpu()
if gpu := computeRuntime.GetGpu(); gpu != 0 {
accelerator := "nvidia.com/gpu"
if len(computeRuntime.GetGpuAccelerator()) != 0 {
accelerator = computeRuntime.GetGpuAccelerator()
}
container.Resources.Requests[corev1.ResourceName(accelerator)] = resource.MustParse(fmt.Sprint(gpu))
container.Resources.Limits[corev1.ResourceName(accelerator)] = resource.MustParse(fmt.Sprint(gpu))
addResourceToContainer(&container, accelerator, gpu)
}

if efa := computeRuntime.GetEfa(); efa != 0 {
addResourceToContainer(&container, "vpc.amazonaws.com/efa", efa)
}

globalEnv := convertEnvironmentVariables(envs)
if len(globalEnv) > 0 {
container.Env = append(container.Env, globalEnv...)
Expand Down Expand Up @@ -528,16 +540,16 @@ func buildWorkerPodTemplate(imageVersion string, envs *api.EnvironmentVariables,
// We are filtering container by name `ray-worker`. If container with this name does not exist
// (should never happen) we are not adding container specific parameters
if container, index, ok := GetContainerByName(podTemplateSpec.Spec.Containers, "ray-worker"); ok {
if computeRuntime.GetGpu() != 0 {
gpu := computeRuntime.GetGpu()
if gpu := computeRuntime.GetGpu(); gpu != 0 {
accelerator := "nvidia.com/gpu"
if len(computeRuntime.GetGpuAccelerator()) != 0 {
accelerator = computeRuntime.GetGpuAccelerator()
}
addResourceToContainer(&container, accelerator, gpu)
}

// need smarter algorithm to filter main container. for example filter by name `ray-worker`
container.Resources.Requests[corev1.ResourceName(accelerator)] = resource.MustParse(fmt.Sprint(gpu))
container.Resources.Limits[corev1.ResourceName(accelerator)] = resource.MustParse(fmt.Sprint(gpu))
if efa := computeRuntime.GetEfa(); efa != 0 {
addResourceToContainer(&container, "vpc.amazonaws.com/efa", efa)
}

globalEnv := convertEnvironmentVariables(envs)
Expand Down Expand Up @@ -808,6 +820,7 @@ func NewComputeTemplate(runtime *api.ComputeTemplate) (*corev1.ConfigMap, error)
"memory": strconv.FormatUint(uint64(runtime.Memory), 10),
"gpu": strconv.FormatUint(uint64(runtime.Gpu), 10),
"gpu_accelerator": runtime.GpuAccelerator,
"efa": strconv.FormatUint(uint64(runtime.Efa), 10),
}
// Add tolerations if defined
if runtime.Tolerations != nil && len(runtime.Tolerations) > 0 {
Expand Down
67 changes: 41 additions & 26 deletions apiserver/pkg/util/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,22 @@ var template = api.ComputeTemplate{
},
}

// templateWorker is the compute template fixture passed to
// buildWorkerPodTemplate in TestBuilWorkerPodTemplate. Unlike a CPU-only
// template it sets non-zero Gpu and Efa counts, so the test can check that
// both extended resources end up in the worker container's requests and limits.
var templateWorker = api.ComputeTemplate{
Name: "",
Namespace: "",
Cpu: 2,
Memory: 8,
Gpu: 4,
Efa: 4,
Tolerations: []*api.PodToleration{
{
Key: "blah1",
Operator: "Exists",
Effect: "NoExecute",
},
},
}

var expectedToleration = corev1.Toleration{
Key: "blah1",
Operator: "Exists",
Expand Down Expand Up @@ -591,34 +607,33 @@ func TestBuildRayCluster(t *testing.T) {
}

// TestBuilWorkerPodTemplate builds a worker pod template from workerGroup and
// templateWorker and verifies that service account, image pull settings,
// environment, tolerations, annotations and labels are propagated, and that
// CPU, memory, GPU and EFA quantities appear in both the requests and the
// limits of the first container.
func TestBuilWorkerPodTemplate(t *testing.T) {
	podSpec, err := buildWorkerPodTemplate("2.4", &api.EnvironmentVariables{}, &workerGroup, &templateWorker)
	// Stop here on failure: every check below dereferences podSpec.
	if !assert.Nil(t, err) {
		return
	}

	assert.Equal(t, "account", podSpec.Spec.ServiceAccountName, "failed to propagate service account")
	assert.Equal(t, "foo", podSpec.Spec.ImagePullSecrets[0].Name, "failed to propagate image pull secret")
	assert.Equal(t, corev1.PullAlways, podSpec.Spec.Containers[0].ImagePullPolicy, "failed to propagate image pull policy")
	assert.True(t, containsEnv(podSpec.Spec.Containers[0].Env, "foo", "bar"), "failed to propagate environment")
	// Only index into Tolerations once the length check has passed, so a
	// missing toleration is reported as a test failure instead of a panic.
	if assert.Len(t, podSpec.Spec.Tolerations, 1, "failed to propagate tolerations") {
		assert.Equal(t, expectedToleration, podSpec.Spec.Tolerations[0], "failed to propagate tolerations")
	}
	assert.Equal(t, "bar", podSpec.Annotations["foo"], "failed to convert annotations")
	assert.Equal(t, expectedLabels, podSpec.Labels, "failed to convert labels")

	// Check Resources
	container := podSpec.Spec.Containers[0]
	resources := container.Resources

	assert.Equal(t, resource.MustParse("2"), resources.Limits[corev1.ResourceCPU], "CPU limit doesn't match")
	assert.Equal(t, resource.MustParse("2"), resources.Requests[corev1.ResourceCPU], "CPU request doesn't match")

	assert.Equal(t, resource.MustParse("8Gi"), resources.Limits[corev1.ResourceMemory], "Memory limit doesn't match")
	assert.Equal(t, resource.MustParse("8Gi"), resources.Requests[corev1.ResourceMemory], "Memory request doesn't match")

	assert.Equal(t, resource.MustParse("4"), resources.Limits["nvidia.com/gpu"], "GPU limit doesn't match")
	assert.Equal(t, resource.MustParse("4"), resources.Requests["nvidia.com/gpu"], "GPU request doesn't match")

	assert.Equal(t, resource.MustParse("4"), resources.Limits["vpc.amazonaws.com/efa"], "EFA limit doesn't match")
	assert.Equal(t, resource.MustParse("4"), resources.Requests["vpc.amazonaws.com/efa"], "EFA request doesn't match")
}

func containsEnv(envs []corev1.EnvVar, key string, val string) bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Template:
memory - required, template memory (GB)
gpu - optional, number of GPUs, default 0
gpu_accelerator - optional, if not defined nvidia.com/gpu is assumed
efa - optional, number of EFA devices (requested as the vpc.amazonaws.com/efa resource), default none
tolerations - optional, tolerations for pod placing, default none
- to_string() -> str: convert toleration to string for printing
- to_dict() -> dict[str, Any] convert to dict
Expand All @@ -106,6 +107,7 @@ def __init__(
memory: int,
gpu: int = 0,
gpu_accelerator: str = None,
efa: int = None,
tolerations: list[Toleration] = None,
):
"""
Expand All @@ -124,6 +126,7 @@ def __init__(
self.memory = memory
self.gpu = gpu
self.gpu_accelerator = gpu_accelerator
self.efa = efa
self.tolerations = tolerations

def to_string(self) -> str:
Expand All @@ -136,6 +139,8 @@ def to_string(self) -> str:
val = val + f", gpu {self.gpu}"
if self.gpu_accelerator is not None:
val = val + f", gpu accelerator {self.gpu_accelerator}"
if self.efa is not None:
val = val + f", efa {self.efa}"
if self.tolerations is None:
return val
val = val + ", tolerations ["
Expand All @@ -158,6 +163,8 @@ def to_dict(self) -> dict[str, Any]:
dct["gpu"] = self.gpu
if self.gpu_accelerator is not None:
dct["gpu accelerator"] = self.gpu_accelerator
if self.efa is not None:
dct["efa"] = self.efa
if self.tolerations is not None:
dct["tolerations"] = [tl.to_dict() for tl in self.tolerations]
return dct
Expand Down Expand Up @@ -199,6 +206,7 @@ def template_decoder(dct: dict[str, Any]) -> Template:
memory=int(dct.get("memory", "0")),
gpu=int(dct.get("gpu", "0")),
gpu_accelerator=dct.get("gpu_accelerator"),
efa=dct.get("efa"),
tolerations=tolerations,
)

Expand Down
6 changes: 6 additions & 0 deletions clients/python-apiserver-client/test/api_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ def test_templates():
tm2_json = json.dumps(temp2.to_dict())
print(f"template 2 JSON: {tm2_json}")

temp3 = Template(name="template3", namespace="namespace", cpu=2, memory=8, gpu=1, efa=4)
print(f"template 3: {temp3.to_string()}")
tm3_json = json.dumps(temp3.to_dict())
print(f"template 3 JSON: {tm3_json}")

assert temp1.to_string() == template_decoder(json.loads(tm1_json)).to_string()
assert temp2.to_string() == template_decoder(json.loads(tm2_json)).to_string()
assert temp3.to_string() == template_decoder(json.loads(tm3_json)).to_string()


def test_volumes():
Expand Down
2 changes: 1 addition & 1 deletion clients/python-apiserver-client/test/kuberay_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_templates():
_, _ = apis.delete_compute_template(ns="default", name="default-template")
# create
toleration = Toleration(key="blah1", operator=TolerationOperation.Exists, effect=TolerationEffect.NoExecute)
template = Template(name="default-template", namespace="default", cpu=2, memory=8, tolerations=[toleration])
template = Template(name="default-template", namespace="default", cpu=2, memory=8, gpu=1, efa=4, tolerations=[toleration])
status, error = apis.create_compute_template(template)
assert status == 200
assert error is None
Expand Down
2 changes: 2 additions & 0 deletions proto/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ message ComputeTemplate {
string gpu_accelerator = 6;
// Optional pod tolerations
repeated PodToleration tolerations = 7;
// Optional. Number of EFA devices (requested as the vpc.amazonaws.com/efa resource)
uint32 efa = 8;
}

// This service is not implemented.
Expand Down
Loading
Loading