From 76edcb443b93626df0d13cdaa8b892e4ca109dbd Mon Sep 17 00:00:00 2001 From: ljstrnadiii Date: Tue, 28 Mar 2023 15:25:12 +0000 Subject: [PATCH 1/2] add cuda support Signed-off-by: ljstrnadiii --- go/tasks/plugins/k8s/dask/dask.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index aa820bce3..d2726b56c 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -187,6 +187,12 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask memory := limits.Memory().String() workerArgs = append(workerArgs, "--memory-limit", memory) } + // If limits includes gpu, assume dask cuda worker cli startup + // https://docs.rapids.ai/api/dask-cuda/nightly/quickstart.html#dask-cuda-worker + // TODO: is this how gpu resources are called? + if limits.Gpu() != nil { + workerArgs[0] = "dask cuda worker" + } } wokerSpec := v1.PodSpec{ From 6a82451b81502ebc02c85bc9c096c988cc8979c4 Mon Sep 17 00:00:00 2001 From: ljstrnadiii Date: Mon, 3 Apr 2023 03:27:29 +0000 Subject: [PATCH 2/2] add simple test for workerspec arg + expected resources Signed-off-by: ljstrnadiii --- go/tasks/plugins/k8s/dask/dask.go | 5 ++--- go/tasks/plugins/k8s/dask/dask_test.go | 28 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index d2726b56c..15487c92f 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -189,9 +189,8 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask } // If limits includes gpu, assume dask cuda worker cli startup // https://docs.rapids.ai/api/dask-cuda/nightly/quickstart.html#dask-cuda-worker - // TODO: is this how gpu resources are called? - if limits.Gpu() != nil { - workerArgs[0] = "dask cuda worker" + if !limits.Name(flytek8s.ResourceNvidiaGPU, "0").IsZero() { + workerArgs[0] = "dask-cuda-worker" } } diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 2eb36ad3b..25004e487 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -340,6 +340,34 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { assert.Contains(t, workerSpec.Containers[0].Args, "2G") } +func TestBuildResourceGPUCudaWorkerArgs(t *testing.T) { + protobufResources := core.Resources{ + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_GPU, + Value: "1", + }, + }, + } + expectedResources, _ := flytek8s.ToK8sResourceRequirements(&protobufResources) + + flyteWorkflowResources := v1.ResourceRequirements{} + + daskResourceHandler := daskResourceHandler{} + taskTemplate := dummyDaskTaskTemplate("", &protobufResources) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, false) + resource, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, resource) + daskJob, ok := resource.(*daskAPI.DaskJob) + assert.True(t, ok) + + // Default Workers + workerSpec := daskJob.Spec.Cluster.Spec.Worker.Spec + assert.Equal(t, *expectedResources, workerSpec.Containers[0].Resources) + assert.Contains(t, workerSpec.Containers[0].Args, "dask-cuda-worker") +} + func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { protobufResources := core.Resources{ Requests: []*core.Resources_ResourceEntry{