diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index e8de072b90..6fa5790b21 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -1,654 +1,654 @@ package ray import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "regexp" - "strconv" - "strings" - "time" - - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "gopkg.in/yaml.v2" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "gopkg.in/yaml.v2" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const ( - rayStateMountPath = "/tmp/ray" - defaultRayStateVolName = "system-ray-state" - rayTaskType = "ray" - KindRayJob = "RayJob" - IncludeDashboard = "include-dashboard" - NodeIPAddress = "node-ip-address" - DashboardHost = "dashboard-host" - DisableUsageStatsStartParameter = "disable-usage-stats" - DisableUsageStatsStartParameterVal = "true" + rayStateMountPath = "/tmp/ray" + defaultRayStateVolName = "system-ray-state" + rayTaskType = "ray" + KindRayJob = "RayJob" + IncludeDashboard = "include-dashboard" + NodeIPAddress = "node-ip-address" + DashboardHost = "dashboard-host" + DisableUsageStatsStartParameter = "disable-usage-stats" + DisableUsageStatsStartParameterVal = "true" ) var logTemplateRegexes = struct { - RayClusterName *regexp.Regexp - RayJobID *regexp.Regexp + RayClusterName *regexp.Regexp + RayJobID *regexp.Regexp }{ - tasklog.MustCreateRegex("rayClusterName"), - tasklog.MustCreateRegex("rayJobID"), + tasklog.MustCreateRegex("rayClusterName"), + 
tasklog.MustCreateRegex("rayJobID"), } type rayJobResourceHandler struct{} func (rayJobResourceHandler) GetProperties() k8s.PluginProperties { - return k8s.PluginProperties{} + return k8s.PluginProperties{} } // BuildResource Creates a new ray job resource func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - taskTemplate, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) - } else if taskTemplate == nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") - } - - rayJob := plugins.RayJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) - } - - var primaryContainer *v1.Container - var primaryContainerIdx int - for idx, c := range podSpec.Containers { - if c.Name == primaryContainerName { - c := c - primaryContainer = &c - primaryContainerIdx = idx - break - } - } - - if primaryContainer == nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error()) - } - - cfg := GetConfig() - - headNodeRayStartParams := make(map[string]string) - if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil { - headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams - } else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 { - headNodeRayStartParams = headNode.StartParameters - } - - if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist { - headNodeRayStartParams[IncludeDashboard] = strconv.FormatBool(GetConfig().IncludeDashboard) - } - - if _, exist := headNodeRayStartParams[NodeIPAddress]; !exist { - headNodeRayStartParams[NodeIPAddress] = cfg.Defaults.HeadNode.IPAddress - } - - if _, exist := headNodeRayStartParams[DashboardHost]; !exist { - headNodeRayStartParams[DashboardHost] = cfg.DashboardHost - } - - if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { - headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal - } - - podSpec.ServiceAccountName = cfg.ServiceAccount - - rayjob, err := constructRayJob(taskCtx, &rayJob, objectMeta, *podSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) - - return rayjob, err + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) + } else if taskTemplate == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") + } + + rayJob := plugins.RayJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, 
flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + var primaryContainer *v1.Container + var primaryContainerIdx int + for idx, c := range podSpec.Containers { + if c.Name == primaryContainerName { + c := c + primaryContainer = &c + primaryContainerIdx = idx + break + } + } + + if primaryContainer == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error()) + } + + cfg := GetConfig() + + headNodeRayStartParams := make(map[string]string) + if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil { + headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams + } else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 { + headNodeRayStartParams = headNode.StartParameters + } + + if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist { + headNodeRayStartParams[IncludeDashboard] = strconv.FormatBool(GetConfig().IncludeDashboard) + } + + if _, exist := headNodeRayStartParams[NodeIPAddress]; !exist { + headNodeRayStartParams[NodeIPAddress] = cfg.Defaults.HeadNode.IPAddress + } + + if _, exist := headNodeRayStartParams[DashboardHost]; !exist { + headNodeRayStartParams[DashboardHost] = cfg.DashboardHost + } + + if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { + headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal + } + + podSpec.ServiceAccountName = cfg.ServiceAccount + + rayjob, err := constructRayJob(taskCtx, &rayJob, objectMeta, *podSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) + + return rayjob, err } func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.RayJob, objectMeta *metav1.ObjectMeta, taskPodSpec v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { - enableIngress := true - cfg := GetConfig() - - headPodSpec := taskPodSpec.DeepCopy() - headPodTemplate, err := buildHeadPodTemplate( - &headPodSpec.Containers[primaryContainerIdx], - headPodSpec, - objectMeta, - taskCtx, - rayJob.RayCluster.HeadGroupSpec, - ) - if err != nil { - return nil, err - } - - rayClusterSpec := rayv1.RayClusterSpec{ - HeadGroupSpec: rayv1.HeadGroupSpec{ - Template: headPodTemplate, - ServiceType: v1.ServiceType(cfg.ServiceType), - EnableIngress: &enableIngress, - RayStartParams: headNodeRayStartParams, - }, - WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, - EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, - } - - for _, spec := range rayJob.RayCluster.WorkerGroupSpec { - workerPodSpec := taskPodSpec.DeepCopy() - workerPodTemplate, err := buildWorkerPodTemplate( - &workerPodSpec.Containers[primaryContainerIdx], - workerPodSpec, - objectMeta, - taskCtx, - spec, - ) - if err != nil { - return nil, err - } - - workerNodeRayStartParams := make(map[string]string) - if spec.RayStartParams != nil { - workerNodeRayStartParams = spec.RayStartParams - } else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 { - workerNodeRayStartParams = workerNode.StartParameters - } - - if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist { - workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress - } - - if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { - 
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal - } - - minReplicas := spec.MinReplicas - if minReplicas > spec.Replicas { - minReplicas = spec.Replicas - } - maxReplicas := spec.MaxReplicas - if maxReplicas < spec.Replicas { - maxReplicas = spec.Replicas - } - - workerNodeSpec := rayv1.WorkerGroupSpec{ - GroupName: spec.GroupName, - MinReplicas: &minReplicas, - MaxReplicas: &maxReplicas, - Replicas: &spec.Replicas, - RayStartParams: workerNodeRayStartParams, - Template: workerPodTemplate, - } - - rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) - } - - serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - if len(serviceAccountName) == 0 { - serviceAccountName = cfg.ServiceAccount - } - - rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName - for index := range rayClusterSpec.WorkerGroupSpecs { - rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName - } - - shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes - ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished - if rayJob.ShutdownAfterJobFinishes { - shutdownAfterJobFinishes = true - ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished - } - - submitterPodSpec := taskPodSpec.DeepCopy() - submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx) - - // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. - var runtimeEnvYaml string - runtimeEnvYaml = rayJob.RuntimeEnvYaml - // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml - if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { - runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) - if err != nil { - return nil, err - } - } - - jobSpec := rayv1.RayJobSpec{ - RayClusterSpec: &rayClusterSpec, - Entrypoint: strings.Join(primaryContainer.Args, " "), - ShutdownAfterJobFinishes: shutdownAfterJobFinishes, - TTLSecondsAfterFinished: *ttlSecondsAfterFinished, - RuntimeEnvYAML: runtimeEnvYaml, - SubmitterPodTemplate: &submitterPodTemplate, - } - - return &rayv1.RayJob{ - TypeMeta: metav1.TypeMeta{ - Kind: KindRayJob, - APIVersion: rayv1.SchemeGroupVersion.String(), - }, - Spec: jobSpec, - ObjectMeta: *objectMeta, - }, nil + enableIngress := true + cfg := GetConfig() + + headPodSpec := taskPodSpec.DeepCopy() + headPodTemplate, err := buildHeadPodTemplate( + &headPodSpec.Containers[primaryContainerIdx], + headPodSpec, + objectMeta, + taskCtx, + rayJob.RayCluster.HeadGroupSpec, + ) + if err != nil { + return nil, err + } + + rayClusterSpec := rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: headPodTemplate, + ServiceType: v1.ServiceType(cfg.ServiceType), + EnableIngress: &enableIngress, + RayStartParams: headNodeRayStartParams, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, + EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, + } + + for _, spec := range rayJob.RayCluster.WorkerGroupSpec { + workerPodSpec := taskPodSpec.DeepCopy() + workerPodTemplate, err := buildWorkerPodTemplate( + &workerPodSpec.Containers[primaryContainerIdx], + workerPodSpec, + objectMeta, + taskCtx, + spec, + ) + if err != nil { + return nil, err + } + + workerNodeRayStartParams := make(map[string]string) + if spec.RayStartParams != nil { + workerNodeRayStartParams = spec.RayStartParams + } else if workerNode := 
cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 { + workerNodeRayStartParams = workerNode.StartParameters + } + + if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist { + workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress + } + + if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { + workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal + } + + minReplicas := spec.MinReplicas + if minReplicas > spec.Replicas { + minReplicas = spec.Replicas + } + maxReplicas := spec.MaxReplicas + if maxReplicas < spec.Replicas { + maxReplicas = spec.Replicas + } + + workerNodeSpec := rayv1.WorkerGroupSpec{ + GroupName: spec.GroupName, + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + Replicas: &spec.Replicas, + RayStartParams: workerNodeRayStartParams, + Template: workerPodTemplate, + } + + rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) + } + + serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + if len(serviceAccountName) == 0 { + serviceAccountName = cfg.ServiceAccount + } + + rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName + for index := range rayClusterSpec.WorkerGroupSpecs { + rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName + } + + shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes + ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished + if rayJob.ShutdownAfterJobFinishes { + shutdownAfterJobFinishes = true + ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished + } + + submitterPodSpec := taskPodSpec.DeepCopy() + submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx) + + // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. 
+ var runtimeEnvYaml string + runtimeEnvYaml = rayJob.RuntimeEnvYaml + // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml + if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { + runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) + if err != nil { + return nil, err + } + } + + jobSpec := rayv1.RayJobSpec{ + RayClusterSpec: &rayClusterSpec, + Entrypoint: strings.Join(primaryContainer.Args, " "), + ShutdownAfterJobFinishes: shutdownAfterJobFinishes, + TTLSecondsAfterFinished: *ttlSecondsAfterFinished, + RuntimeEnvYAML: runtimeEnvYaml, + SubmitterPodTemplate: &submitterPodTemplate, + } + + return &rayv1.RayJob{ + TypeMeta: metav1.TypeMeta{ + Kind: KindRayJob, + APIVersion: rayv1.SchemeGroupVersion.String(), + }, + Spec: jobSpec, + ObjectMeta: *objectMeta, + }, nil } func convertBase64RuntimeEnvToYaml(s string) (string, error) { - // Decode from base64 - data, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return "", err - } - - // Unmarshal JSON - var obj map[string]interface{} - err = json.Unmarshal(data, &obj) - if err != nil { - return "", err - } - - // Convert to YAML - y, err := yaml.Marshal(&obj) - if err != nil { - return "", err - } - - return string(y), nil + // Decode from base64 + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return "", err + } + + // Unmarshal JSON + var obj map[string]interface{} + err = json.Unmarshal(data, &obj) + if err != nil { + return "", err + } + + // Convert to YAML + y, err := yaml.Marshal(&obj) + if err != nil { + return "", err + } + + return string(y), nil } func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) { - cfg := GetConfig() - if cfg.LogsSidecar == nil { - return - } - sidecar := cfg.LogsSidecar.DeepCopy() - - // Ray logs integration - var rayStateVolMount *v1.VolumeMount - // Look for an existing volume mount on the primary container, mounted at /tmp/ray - for _, vm := range primaryContainer.VolumeMounts { - if vm.MountPath == rayStateMountPath { - vm := vm - rayStateVolMount = &vm - break - } - } - // No existing volume mount exists at /tmp/ray. We create a new volume and volume - // mount and add it to the pod and container specs respectively - if rayStateVolMount == nil { - vol := v1.Volume{ - Name: defaultRayStateVolName, - VolumeSource: v1.VolumeSource{ - EmptyDir: &v1.EmptyDirVolumeSource{}, - }, - } - podSpec.Volumes = append(podSpec.Volumes, vol) - volMount := v1.VolumeMount{ - Name: defaultRayStateVolName, - MountPath: rayStateMountPath, - } - primaryContainer.VolumeMounts = append(primaryContainer.VolumeMounts, volMount) - rayStateVolMount = &volMount - } - // We need to mirror the ray state volume mount into the sidecar as readonly, - // so that we can read the logs written by the head node. - readOnlyRayStateVolMount := *rayStateVolMount.DeepCopy() - readOnlyRayStateVolMount.ReadOnly = true - - // Update volume mounts on sidecar - // If one already exists with the desired mount path, simply replace it. Otherwise, - // add it to sidecar's volume mounts. 
- foundExistingSidecarVolMount := false - for idx, vm := range sidecar.VolumeMounts { - if vm.MountPath == rayStateMountPath { - foundExistingSidecarVolMount = true - sidecar.VolumeMounts[idx] = readOnlyRayStateVolMount - } - } - if !foundExistingSidecarVolMount { - sidecar.VolumeMounts = append(sidecar.VolumeMounts, readOnlyRayStateVolMount) - } - - // Add sidecar to containers - podSpec.Containers = append(podSpec.Containers, *sidecar) + cfg := GetConfig() + if cfg.LogsSidecar == nil { + return + } + sidecar := cfg.LogsSidecar.DeepCopy() + + // Ray logs integration + var rayStateVolMount *v1.VolumeMount + // Look for an existing volume mount on the primary container, mounted at /tmp/ray + for _, vm := range primaryContainer.VolumeMounts { + if vm.MountPath == rayStateMountPath { + vm := vm + rayStateVolMount = &vm + break + } + } + // No existing volume mount exists at /tmp/ray. We create a new volume and volume + // mount and add it to the pod and container specs respectively + if rayStateVolMount == nil { + vol := v1.Volume{ + Name: defaultRayStateVolName, + VolumeSource: v1.VolumeSource{ + EmptyDir: &v1.EmptyDirVolumeSource{}, + }, + } + podSpec.Volumes = append(podSpec.Volumes, vol) + volMount := v1.VolumeMount{ + Name: defaultRayStateVolName, + MountPath: rayStateMountPath, + } + primaryContainer.VolumeMounts = append(primaryContainer.VolumeMounts, volMount) + rayStateVolMount = &volMount + } + // We need to mirror the ray state volume mount into the sidecar as readonly, + // so that we can read the logs written by the head node. + readOnlyRayStateVolMount := *rayStateVolMount.DeepCopy() + readOnlyRayStateVolMount.ReadOnly = true + + // Update volume mounts on sidecar + // If one already exists with the desired mount path, simply replace it. Otherwise, + // add it to sidecar's volume mounts. + foundExistingSidecarVolMount := false + for idx, vm := range sidecar.VolumeMounts { + if vm.MountPath == rayStateMountPath { + foundExistingSidecarVolMount = true + sidecar.VolumeMounts[idx] = readOnlyRayStateVolMount + } + } + if !foundExistingSidecarVolMount { + sidecar.VolumeMounts = append(sidecar.VolumeMounts, readOnlyRayStateVolMount) + } + + // Add sidecar to containers + podSpec.Containers = append(podSpec.Containers, *sidecar) } func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.HeadGroupSpec) (v1.PodTemplateSpec, error) { - // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 - // They should always be the same, so we could hard code here. - primaryContainer.Name = "ray-head" - - envs := []v1.EnvVar{ - { - Name: "MY_POD_IP", - ValueFrom: &v1.EnvVarSource{ - FieldRef: &v1.ObjectFieldSelector{ - FieldPath: "status.podIP", - }, - }, - }, - } - - primaryContainer.Args = []string{} - - primaryContainer.Env = append(primaryContainer.Env, envs...) - - ports := []v1.ContainerPort{ - { - Name: "redis", - ContainerPort: 6379, - }, - { - Name: "head", - ContainerPort: 10001, - }, - { - Name: "dashboard", - ContainerPort: 8265, - }, - } - - primaryContainer.Ports = append(primaryContainer.Ports, ports...) 
- - // Inject a sidecar for capturing and exposing Ray job logs - injectLogsSidecar(primaryContainer, podSpec) - - // Overwrite head pod taskResources if specified - if spec.Resources != nil { - res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) - if err != nil { - return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification HeadGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) - } - - primaryContainer.Resources = *res - } - - podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, - ObjectMeta: *objectMeta, - } - cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) - return podTemplateSpec, nil + // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 + // They should always be the same, so we could hard code here. + primaryContainer.Name = "ray-head" + + envs := []v1.EnvVar{ + { + Name: "MY_POD_IP", + ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + } + + primaryContainer.Args = []string{} + + primaryContainer.Env = append(primaryContainer.Env, envs...) + + ports := []v1.ContainerPort{ + { + Name: "redis", + ContainerPort: 6379, + }, + { + Name: "head", + ContainerPort: 10001, + }, + { + Name: "dashboard", + ContainerPort: 8265, + }, + } + + primaryContainer.Ports = append(primaryContainer.Ports, ports...) + + // Inject a sidecar for capturing and exposing Ray job logs + injectLogsSidecar(primaryContainer, podSpec) + + // Overwrite head pod taskResources if specified + if spec.Resources != nil { + res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) + if err != nil { + return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification HeadGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) + } + + primaryContainer.Resources = *res + } + + podTemplateSpec := v1.PodTemplateSpec{ + Spec: *podSpec, + ObjectMeta: *objectMeta, + } + cfg := config.GetK8sPluginConfig() + podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + return podTemplateSpec, nil } func buildSubmitterPodTemplate(podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { - submitterPodSpec := podSpec.DeepCopy() + submitterPodSpec := podSpec.DeepCopy() - podTemplateSpec := v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *submitterPodSpec, - } + podTemplateSpec := v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *submitterPodSpec, + } - cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) - return 
podTemplateSpec + cfg := config.GetK8sPluginConfig() + podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + return podTemplateSpec } func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.WorkerGroupSpec) (v1.PodTemplateSpec, error) { - // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185 - // They should always be the same, so we could hard code here. - - primaryContainer.Name = "ray-worker" - - primaryContainer.Args = []string{} - - envs := []v1.EnvVar{ - { - Name: "RAY_DISABLE_DOCKER_CPU_WARNING", - Value: "1", - }, - { - Name: "TYPE", - Value: "worker", - }, - { - Name: "CPU_REQUEST", - ValueFrom: &v1.EnvVarSource{ - ResourceFieldRef: &v1.ResourceFieldSelector{ - ContainerName: "ray-worker", - Resource: "requests.cpu", - }, - }, - }, - { - Name: "CPU_LIMITS", - ValueFrom: &v1.EnvVarSource{ - ResourceFieldRef: &v1.ResourceFieldSelector{ - ContainerName: "ray-worker", - Resource: "limits.cpu", - }, - }, - }, - { - Name: "MEMORY_REQUESTS", - ValueFrom: &v1.EnvVarSource{ - ResourceFieldRef: &v1.ResourceFieldSelector{ - ContainerName: "ray-worker", - Resource: "requests.cpu", - }, - }, - }, - { - Name: "MEMORY_LIMITS", - ValueFrom: &v1.EnvVarSource{ - ResourceFieldRef: &v1.ResourceFieldSelector{ - ContainerName: "ray-worker", - Resource: "limits.cpu", - }, - }, - }, - { - Name: "MY_POD_NAME", - ValueFrom: &v1.EnvVarSource{ - FieldRef: &v1.ObjectFieldSelector{ - FieldPath: "metadata.name", - }, - }, - }, - { - Name: "MY_POD_IP", - ValueFrom: &v1.EnvVarSource{ - FieldRef: &v1.ObjectFieldSelector{ - FieldPath: "status.podIP", - }, - }, - }, - } - - primaryContainer.Env = append(primaryContainer.Env, envs...) - - primaryContainer.Lifecycle = &v1.Lifecycle{ - PreStop: &v1.LifecycleHandler{ - Exec: &v1.ExecAction{ - Command: []string{ - "/bin/sh", "-c", "ray stop", - }, - }, - }, - } - - ports := []v1.ContainerPort{ - { - Name: "redis", - ContainerPort: 6379, - }, - { - Name: "head", - ContainerPort: 10001, - }, - { - Name: "dashboard", - ContainerPort: 8265, - }, - } - primaryContainer.Ports = append(primaryContainer.Ports, ports...) 
- - // Overwrite worker pod taskResources if specified - if spec.Resources != nil { - res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) - if err != nil { - return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on WorkerGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) - } - - primaryContainer.Resources = *res - } - - podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, - ObjectMeta: *objectMetadata, - } - podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) - return podTemplateSpec, nil + // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185 + // They should always be the same, so we could hard code here. + + primaryContainer.Name = "ray-worker" + + primaryContainer.Args = []string{} + + envs := []v1.EnvVar{ + { + Name: "RAY_DISABLE_DOCKER_CPU_WARNING", + Value: "1", + }, + { + Name: "TYPE", + Value: "worker", + }, + { + Name: "CPU_REQUEST", + ValueFrom: &v1.EnvVarSource{ + ResourceFieldRef: &v1.ResourceFieldSelector{ + ContainerName: "ray-worker", + Resource: "requests.cpu", + }, + }, + }, + { + Name: "CPU_LIMITS", + ValueFrom: &v1.EnvVarSource{ + ResourceFieldRef: &v1.ResourceFieldSelector{ + ContainerName: "ray-worker", + Resource: "limits.cpu", + }, + }, + }, + { + Name: "MEMORY_REQUESTS", + ValueFrom: &v1.EnvVarSource{ + ResourceFieldRef: &v1.ResourceFieldSelector{ + ContainerName: "ray-worker", + Resource: "requests.cpu", + }, + }, + }, + { + Name: "MEMORY_LIMITS", + ValueFrom: &v1.EnvVarSource{ + ResourceFieldRef: &v1.ResourceFieldSelector{ + ContainerName: "ray-worker", + Resource: "limits.cpu", + }, + }, + }, + { + Name: "MY_POD_NAME", + ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + FieldPath: "metadata.name", + }, + }, + }, + { + Name: "MY_POD_IP", + ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + } + + primaryContainer.Env = append(primaryContainer.Env, envs...) + + primaryContainer.Lifecycle = &v1.Lifecycle{ + PreStop: &v1.LifecycleHandler{ + Exec: &v1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", "ray stop", + }, + }, + }, + } + + ports := []v1.ContainerPort{ + { + Name: "redis", + ContainerPort: 6379, + }, + { + Name: "head", + ContainerPort: 10001, + }, + { + Name: "dashboard", + ContainerPort: 8265, + }, + } + primaryContainer.Ports = append(primaryContainer.Ports, ports...) 
+ + // Overwrite worker pod taskResources if specified + if spec.Resources != nil { + res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) + if err != nil { + return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on WorkerGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) + } + + primaryContainer.Resources = *res + } + + podTemplateSpec := v1.PodTemplateSpec{ + Spec: *podSpec, + ObjectMeta: *objectMetadata, + } + podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + return podTemplateSpec, nil } func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) { - return &rayv1.RayJob{ - TypeMeta: metav1.TypeMeta{ - Kind: KindRayJob, - APIVersion: rayv1.SchemeGroupVersion.String(), - }, - }, nil + return &rayv1.RayJob{ + TypeMeta: metav1.TypeMeta{ + Kind: KindRayJob, + APIVersion: rayv1.SchemeGroupVersion.String(), + }, + }, nil } func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) { - logPlugin, err := logs.InitializeLogPlugins(&logConfig) - if err != nil { - return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err) - } - - var taskLogs []*core.TaskLog - - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - input := tasklog.Input{ - Namespace: rayJob.Namespace, - TaskExecutionID: taskExecID, - ExtraTemplateVars: []tasklog.TemplateVar{}, - } - if rayJob.Status.JobId != "" { - input.ExtraTemplateVars = append( - input.ExtraTemplateVars, - tasklog.TemplateVar{ - Regex: logTemplateRegexes.RayJobID, - Value: rayJob.Status.JobId, - }, - ) - } - if rayJob.Status.RayClusterName != "" { - input.ExtraTemplateVars = append( - input.ExtraTemplateVars, - tasklog.TemplateVar{ - Regex: logTemplateRegexes.RayClusterName, - Value: rayJob.Status.RayClusterName, - }, - ) - } - - // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs - // RayJob CRD does not include the name of the worker or head pod for now - logOutput, err := logPlugin.GetTaskLogs(input) - if err != nil { - return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) - } - taskLogs = append(taskLogs, logOutput.TaskLogs...) - - // Handling for Ray Dashboard - dashboardURLTemplate := GetConfig().DashboardURLTemplate - if dashboardURLTemplate != nil && - rayJob.Status.DashboardURL != "" && - rayJob.Status.JobStatus == rayv1.JobStatusRunning { - dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) - if err != nil { - return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) - } - taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) - } - - return &pluginsCore.TaskInfo{Logs: taskLogs}, nil + logPlugin, err := logs.InitializeLogPlugins(&logConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize log plugins. 
Error: %w", err) + } + + var taskLogs []*core.TaskLog + + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() + input := tasklog.Input{ + Namespace: rayJob.Namespace, + TaskExecutionID: taskExecID, + ExtraTemplateVars: []tasklog.TemplateVar{}, + } + if rayJob.Status.JobId != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayJobID, + Value: rayJob.Status.JobId, + }, + ) + } + if rayJob.Status.RayClusterName != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayClusterName, + Value: rayJob.Status.RayClusterName, + }, + ) + } + + // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs + // RayJob CRD does not include the name of the worker or head pod for now + logOutput, err := logPlugin.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) + } + taskLogs = append(taskLogs, logOutput.TaskLogs...) + + // Handling for Ray Dashboard + dashboardURLTemplate := GetConfig().DashboardURLTemplate + if dashboardURLTemplate != nil && + rayJob.Status.DashboardURL != "" && + rayJob.Status.JobStatus == rayv1.JobStatusRunning { + dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) + } + taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) + } + + return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { - rayJob := resource.(*rayv1.RayJob) - info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - - if len(rayJob.Status.JobDeploymentStatus) == 0 { - return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", info), nil - } - - var phaseInfo pluginsCore.PhaseInfo - - // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster - switch rayJob.Status.JobDeploymentStatus { - case rayv1.JobDeploymentStatusInitializing: - phaseInfo, err = pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil - case rayv1.JobDeploymentStatusRunning: - phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil - case rayv1.JobDeploymentStatusComplete: - phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil - case rayv1.JobDeploymentStatusFailed: - failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) - phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil - default: - // We already handle all known deployment status, so this should never happen unless a future version of ray - // introduced a new job status. 
- phaseInfo, err = pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus) - } - - phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) - if phaseVersionUpdateErr != nil { - return phaseInfo, phaseVersionUpdateErr - } - - return phaseInfo, err + rayJob := resource.(*rayv1.RayJob) + info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + if len(rayJob.Status.JobDeploymentStatus) == 0 { + return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", info), nil + } + + var phaseInfo pluginsCore.PhaseInfo + + // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster + switch rayJob.Status.JobDeploymentStatus { + case rayv1.JobDeploymentStatusInitializing: + phaseInfo, err = pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil + case rayv1.JobDeploymentStatusRunning: + phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + case rayv1.JobDeploymentStatusComplete: + phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil + case rayv1.JobDeploymentStatusFailed: + failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) + phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil + default: + // We already handle all known deployment status, so this should never happen unless a future version of ray + // introduced a new job status. + phaseInfo, err = pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus) + } + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, err } func init() { - if err := rayv1.AddToScheme(scheme.Scheme); err != nil { - panic(err) - } - - pluginmachinery.PluginRegistry().RegisterK8sPlugin( - k8s.PluginEntry{ - ID: rayTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{rayTaskType}, - ResourceToWatch: &rayv1.RayJob{}, - Plugin: rayJobResourceHandler{}, - IsDefault: false, - CustomKubeClient: func(ctx context.Context) (pluginsCore.KubeClient, error) { - remoteConfig := GetConfig().RemoteClusterConfig - if !remoteConfig.Enabled { - // use controller-runtime KubeClient - return nil, nil - } - - kubeConfig, err := k8s.KubeClientConfig(remoteConfig.Endpoint, remoteConfig.Auth) - if err != nil { - return nil, err - } - - return k8s.NewDefaultKubeClient(kubeConfig) - }, - }) + if err := rayv1.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: rayTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{rayTaskType}, + ResourceToWatch: &rayv1.RayJob{}, + Plugin: rayJobResourceHandler{}, + IsDefault: false, + CustomKubeClient: func(ctx context.Context) (pluginsCore.KubeClient, error) { + remoteConfig := GetConfig().RemoteClusterConfig + if !remoteConfig.Enabled { + // use controller-runtime KubeClient + return nil, nil + } + + kubeConfig, err := k8s.KubeClientConfig(remoteConfig.Endpoint, remoteConfig.Auth) + if err != nil { + return nil, err + } + + return k8s.NewDefaultKubeClient(kubeConfig) + }, + }) } 
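Editor's note between the two file diffs: the runtime_env backward-compatibility branch in constructRayJob above reduces to a base64 -> JSON -> YAML round trip whose result feeds RayJobSpec.RuntimeEnvYAML. Below is a minimal, self-contained sketch of that path; the conversion logic mirrors the plugin's convertBase64RuntimeEnvToYaml, while the helper name `convert` and the sample payload are assumptions for illustration only, not part of the patch.

package main

import (
	"encoding/base64"
	"encoding/json"
	"fmt"

	"gopkg.in/yaml.v2"
)

// convert mirrors the plugin's convertBase64RuntimeEnvToYaml: decode the
// base64-encoded runtime_env string, parse it as a JSON object, and re-emit
// it as YAML suitable for RayJobSpec.RuntimeEnvYAML.
func convert(s string) (string, error) {
	data, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return "", err
	}
	var obj map[string]interface{}
	if err := json.Unmarshal(data, &obj); err != nil {
		return "", err
	}
	y, err := yaml.Marshal(&obj)
	if err != nil {
		return "", err
	}
	return string(y), nil
}

func main() {
	// Hypothetical runtime_env payload: {"pip":["requests"],"env_vars":{"FOO":"bar"}}
	encoded := base64.StdEncoding.EncodeToString([]byte(`{"pip":["requests"],"env_vars":{"FOO":"bar"}}`))
	out, err := convert(encoded)
	if err != nil {
		panic(err)
	}
	fmt.Print(out)
	// yaml.v2 sorts map keys, so this prints:
	// env_vars:
	//   FOO: bar
	// pip:
	// - requests
}

The tests that follow exercise the surrounding BuildResource logic (start params, service accounts, replicas, tolerations) rather than this conversion directly.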
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index d2e35fb67d..73fd7a9e92 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -1,1201 +1,1201 @@ package ray import ( - "context" - "reflect" - "testing" - "time" - - structpb "github.com/golang/protobuf/ptypes/struct" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "context" + "reflect" + "testing" + "time" + + structpb "github.com/golang/protobuf/ptypes/struct" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const ( - testImage = "image://" - serviceAccount = "ray_sa" + testImage = "image://" + serviceAccount = "ray_sa" ) var ( - dummyEnvVars = []*core.KeyValuePair{ - {Key: "Env_Var", Value: "Env_Val"}, - } - - testArgs = []string{ - "test-args", - } - - resourceRequirements = &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1000m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - Requests: corev1.ResourceList{ - 
corev1.ResourceCPU: resource.MustParse("100m"), - corev1.ResourceMemory: resource.MustParse("512Mi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - } - - workerGroupName = "worker-group" + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "test-args", + } + + resourceRequirements = &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + workerGroupName = "worker-group" ) func transformRayJobToCustomObj(rayJob *plugins.RayJob) *structpb.Struct { - structObj, err := utils.MarshalObjToStruct(rayJob) - if err != nil { - panic(err) - } - return structObj + structObj, err := utils.MarshalObjToStruct(rayJob) + if err != nil { + panic(err) + } + return structObj } func transformPodSpecToTaskTemplateTarget(podSpec *corev1.PodSpec) *core.TaskTemplate_K8SPod { - structObj, err := utils.MarshalObjToStruct(&podSpec) - if err != nil { - panic(err) - } - return &core.TaskTemplate_K8SPod{ - K8SPod: &core.K8SPod{ - PodSpec: structObj, - }, - } + structObj, err := utils.MarshalObjToStruct(&podSpec) + if err != nil { + panic(err) + } + return &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: structObj, + }, + } } func dummyRayCustomObj() *plugins.RayJob { - return &plugins.RayJob{ - RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: true, - }, - ShutdownAfterJobFinishes: true, - TtlSecondsAfterFinished: 120, - } + return &plugins.RayJob{ + RayCluster: &plugins.RayCluster{ + HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + }, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, + } } func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate { - return &core.TaskTemplate{ - Id: &core.Identifier{Name: id}, - Type: "container", - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - Env: dummyEnvVars, - }, - }, - Custom: transformRayJobToCustomObj(rayJob), - } + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: transformRayJobToCustomObj(rayJob), + } } func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage, serviceAccount string) pluginsCore.TaskExecutionContext { - taskCtx := &mocks.TaskExecutionContext{} - inputReader := &pluginIOMocks.InputReader{} - inputReader.OnGetInputPrefixPath().Return("/input/prefix") - inputReader.OnGetInputPath().Return("/input") - inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - taskCtx.OnInputReader().Return(inputReader) - - outputReader := 
&pluginIOMocks.OutputWriter{} - outputReader.OnGetOutputPath().Return("/data/outputs.pb") - outputReader.OnGetOutputPrefixPath().Return("/data/") - outputReader.OnGetRawOutputPrefix().Return("") - outputReader.OnGetCheckpointPrefix().Return("/checkpoint") - outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") - taskCtx.OnOutputWriter().Return(outputReader) - - taskReader := &mocks.TaskReader{} - taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) - taskCtx.OnTaskReader().Return(taskReader) - - tID := &mocks.TaskExecutionID{} - tID.OnGetID().Return(core.TaskExecutionIdentifier{ - NodeExecutionId: &core.NodeExecutionIdentifier{ - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my_name", - Project: "my_project", - Domain: "my_domain", - }, - }, - }) - tID.OnGetGeneratedName().Return("some-acceptable-name") - - overrides := &mocks.TaskOverrides{} - overrides.OnGetResources().Return(resources) - overrides.OnGetExtendedResources().Return(extendedResources) - overrides.OnGetContainerImage().Return(containerImage) - - taskExecutionMetadata := &mocks.TaskExecutionMetadata{} - taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) - taskExecutionMetadata.OnGetNamespace().Return("test-namespace") - taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) - taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) - taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ - Kind: "node", - Name: "blah", - }) - taskExecutionMetadata.OnIsInterruptible().Return(true) - taskExecutionMetadata.OnGetOverrides().Return(overrides) - taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) - taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) - taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{ - RunAs: &core.Identity{K8SServiceAccount: serviceAccount}, - }) - taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) - taskExecutionMetadata.OnGetConsoleURL().Return("") - taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - return taskCtx + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return("/input/prefix") + inputReader.OnGetInputPath().Return("/input") + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return("/data/outputs.pb") + outputReader.OnGetOutputPrefixPath().Return("/data/") + outputReader.OnGetRawOutputPrefix().Return("") + outputReader.OnGetCheckpointPrefix().Return("/checkpoint") + outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.OnGetGeneratedName().Return("some-acceptable-name") + + overrides := &mocks.TaskOverrides{} + overrides.OnGetResources().Return(resources) + overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return(containerImage) + + taskExecutionMetadata 
:= &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(overrides) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: serviceAccount}, + }) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) + taskExecutionMetadata.OnGetConsoleURL().Return("") + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx } func TestBuildResourceRay(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) - toleration := []corev1.Toleration{{ - Key: "storage", - Value: "dedicated", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }} - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) - assert.Nil(t, err) - - rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) - assert.Nil(t, err) - - assert.NotNil(t, RayResource) - ray, ok := RayResource.(*rayv1.RayJob) - assert.True(t, ok) - - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) - assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) - assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) - - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, - map[string]string{ - "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", - "node-ip-address": "$MY_POD_IP", "num-cpus": "1", - }) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) - - workerReplica := int32(3) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, 
map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) - - // Make sure the default service account is being used if SA is not provided in the task context - rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "") - RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) - assert.Nil(t, err) - assert.NotNil(t, RayResource) - ray, ok = RayResource.(*rayv1.RayJob) - assert.True(t, ok) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + toleration := []corev1.Toleration{{ + Key: "storage", + Value: "dedicated", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }} + err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) + assert.Nil(t, err) + + rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) + assert.Nil(t, err) + + assert.NotNil(t, RayResource) + ray, ok := RayResource.(*rayv1.RayJob) + assert.True(t, ok) + + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) + assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) + + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, + map[string]string{ + "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", + "node-ip-address": "$MY_POD_IP", "num-cpus": "1", + }) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) + + workerReplica := int32(3) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) + + // Make sure the default service account is being used if SA is not provided in the task context + rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "") + RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) + assert.Nil(t, err) + assert.NotNil(t, RayResource) + ray, ok = 
RayResource.(*rayv1.RayJob) + assert.True(t, ok) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount) } func TestBuildResourceRayContainerImage(t *testing.T) { - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) - - fixtures := []struct { - name string - resources *corev1.ResourceRequirements - containerImageOverride string - }{ - { - "without overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - "", - }, - { - "with overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - "container-image-override", - }, - } - - for _, f := range fixtures { - t.Run(f.name, func(t *testing.T) { - taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) - taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - var expectedContainerImage string - if len(f.containerImageOverride) > 0 { - expectedContainerImage = f.containerImageOverride - } else { - expectedContainerImage = testImage - } - - // Head node - headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) - - // Worker node - workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec - assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) - }) - } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *corev1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "container-image-override", + }, + } + + for _, f := range fixtures { + t.Run(f.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) + taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + var expectedContainerImage string + if len(f.containerImageOverride) > 0 { + expectedContainerImage = f.containerImageOverride + } else { + expectedContainerImage = testImage + } + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) + }) + } } func TestBuildResourceRayExtendedResources(t *testing.T) { - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: 
"gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, - })) - - params := []struct { - name string - resources *corev1.ResourceRequirements - extendedResourcesBase *core.ExtendedResources - extendedResourcesOverride *core.ExtendedResources - expectedNsr []corev1.NodeSelectorTerm - expectedTol []corev1.Toleration - }{ - { - "without overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-t4", - }, - }, - nil, - []corev1.NodeSelectorTerm{ - { - MatchExpressions: []corev1.NodeSelectorRequirement{ - { - Key: "gpu-node-label", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"nvidia-tesla-t4"}, - }, - }, - }, - }, - []corev1.Toleration{ - { - Key: "gpu-node-label", - Value: "nvidia-tesla-t4", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - { - "with overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-t4", - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-a100", - PartitionSizeValue: &core.GPUAccelerator_PartitionSize{ - PartitionSize: "1g.5gb", - }, - }, - }, - []corev1.NodeSelectorTerm{ - { - MatchExpressions: []corev1.NodeSelectorRequirement{ - { - Key: "gpu-node-label", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"nvidia-tesla-a100"}, - }, - { - Key: "gpu-partition-size", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"1g.5gb"}, - }, - }, - }, - }, - []corev1.Toleration{ - { - Key: "gpu-node-label", - Value: "nvidia-tesla-a100", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "gpu-partition-size", - Value: "1g.5gb", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - } - - for _, p := range params { - t.Run(p.name, func(t *testing.T) { - taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) - taskTemplate.ExtendedResources = p.extendedResourcesBase - taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - // Head node - headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - assert.EqualValues( - t, - p.expectedNsr, - headNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) - assert.EqualValues( - t, - p.expectedTol, - headNodeSpec.Tolerations, - ) - - // Worker node - workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec - assert.EqualValues( - t, - p.expectedNsr, - workerNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) - assert.EqualValues( - t, - p.expectedTol, - workerNodeSpec.Tolerations, - ) - }) - } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + })) + + params := []struct { + name string + resources 
*corev1.ResourceRequirements + extendedResourcesBase *core.ExtendedResources + extendedResourcesOverride *core.ExtendedResources + expectedNsr []corev1.NodeSelectorTerm + expectedTol []corev1.Toleration + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-t4", + }, + }, + nil, + []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "gpu-node-label", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"nvidia-tesla-t4"}, + }, + }, + }, + }, + []corev1.Toleration{ + { + Key: "gpu-node-label", + Value: "nvidia-tesla-t4", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-t4", + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-a100", + PartitionSizeValue: &core.GPUAccelerator_PartitionSize{ + PartitionSize: "1g.5gb", + }, + }, + }, + []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "gpu-node-label", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"nvidia-tesla-a100"}, + }, + { + Key: "gpu-partition-size", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"1g.5gb"}, + }, + }, + }, + }, + []corev1.Toleration{ + { + Key: "gpu-node-label", + Value: "nvidia-tesla-a100", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + { + Key: "gpu-partition-size", + Value: "1g.5gb", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + taskTemplate.ExtendedResources = p.extendedResourcesBase + taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.EqualValues( + t, + p.expectedNsr, + headNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, + ) + assert.EqualValues( + t, + p.expectedTol, + headNodeSpec.Tolerations, + ) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.EqualValues( + t, + p.expectedNsr, + workerNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, + ) + assert.EqualValues( + t, + p.expectedTol, + workerNodeSpec.Tolerations, + ) + }) + } } func TestBuildResourceRayCustomResources(t *testing.T) { - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) - - headResourceEntries := []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "10"}, - {Name: core.Resources_MEMORY, Value: "10Gi"}, - {Name: core.Resources_GPU, Value: "10"}, - } - headResources := &core.Resources{Requests: headResourceEntries, Limits: 
headResourceEntries} - - expectedHeadResources, err := flytek8s.ToK8sResourceRequirements(headResources) - require.NoError(t, err) - - workerResourceEntries := []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "20"}, - {Name: core.Resources_MEMORY, Value: "20Gi"}, - {Name: core.Resources_GPU, Value: "20"}, - } - workerResources := &core.Resources{Requests: workerResourceEntries, Limits: workerResourceEntries} - - expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(workerResources) - require.NoError(t, err) - - params := []struct { - name string - taskResources *corev1.ResourceRequirements - headResources *core.Resources - workerResources *core.Resources - expectedSubmitterResources *corev1.ResourceRequirements - expectedHeadResources *corev1.ResourceRequirements - expectedWorkerResources *corev1.ResourceRequirements - }{ - { - name: "task resources", - taskResources: resourceRequirements, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: resourceRequirements, - expectedWorkerResources: resourceRequirements, - }, - { - name: "custom worker and head resources", - taskResources: resourceRequirements, - headResources: headResources, - workerResources: workerResources, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: expectedHeadResources, - expectedWorkerResources: expectedWorkerResources, - }, - } - - for _, p := range params { - t.Run(p.name, func(t *testing.T) { - rayJobInput := dummyRayCustomObj() - - if p.headResources != nil { - rayJobInput.RayCluster.HeadGroupSpec.Resources = p.headResources - } - - if p.workerResources != nil { - for _, spec := range rayJobInput.RayCluster.WorkerGroupSpec { - spec.Resources = p.workerResources - } - } - - taskTemplate := dummyRayTaskTemplate("ray-id", rayJobInput) - taskContext := dummyRayTaskContext(taskTemplate, p.taskResources, nil, "", serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - submitterPodResources := rayJob.Spec.SubmitterPodTemplate.Spec.Containers[0].Resources - assert.EqualValues(t, - p.expectedSubmitterResources, - &submitterPodResources, - ) - - headPodResources := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources - assert.EqualValues(t, - p.expectedHeadResources, - &headPodResources, - ) - - for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { - workerPodResources := workerGroupSpec.Template.Spec.Containers[0].Resources - assert.EqualValues(t, - p.expectedWorkerResources, - &workerPodResources, - ) - } - }) - } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + headResourceEntries := []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "10"}, + {Name: core.Resources_MEMORY, Value: "10Gi"}, + {Name: core.Resources_GPU, Value: "10"}, + } + headResources := &core.Resources{Requests: headResourceEntries, Limits: headResourceEntries} + + expectedHeadResources, err := flytek8s.ToK8sResourceRequirements(headResources) + require.NoError(t, err) + + workerResourceEntries := []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "20"}, + {Name: core.Resources_MEMORY, Value: "20Gi"}, + {Name: core.Resources_GPU, Value: "20"}, + } + workerResources := &core.Resources{Requests: workerResourceEntries, Limits: workerResourceEntries} + + expectedWorkerResources, 
err := flytek8s.ToK8sResourceRequirements(workerResources) + require.NoError(t, err) + + params := []struct { + name string + taskResources *corev1.ResourceRequirements + headResources *core.Resources + workerResources *core.Resources + expectedSubmitterResources *corev1.ResourceRequirements + expectedHeadResources *corev1.ResourceRequirements + expectedWorkerResources *corev1.ResourceRequirements + }{ + { + name: "task resources", + taskResources: resourceRequirements, + expectedSubmitterResources: resourceRequirements, + expectedHeadResources: resourceRequirements, + expectedWorkerResources: resourceRequirements, + }, + { + name: "custom worker and head resources", + taskResources: resourceRequirements, + headResources: headResources, + workerResources: workerResources, + expectedSubmitterResources: resourceRequirements, + expectedHeadResources: expectedHeadResources, + expectedWorkerResources: expectedWorkerResources, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + rayJobInput := dummyRayCustomObj() + + if p.headResources != nil { + rayJobInput.RayCluster.HeadGroupSpec.Resources = p.headResources + } + + if p.workerResources != nil { + for _, spec := range rayJobInput.RayCluster.WorkerGroupSpec { + spec.Resources = p.workerResources + } + } + + taskTemplate := dummyRayTaskTemplate("ray-id", rayJobInput) + taskContext := dummyRayTaskContext(taskTemplate, p.taskResources, nil, "", serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + submitterPodResources := rayJob.Spec.SubmitterPodTemplate.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedSubmitterResources, + &submitterPodResources, + ) + + headPodResources := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedHeadResources, + &headPodResources, + ) + + for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { + workerPodResources := workerGroupSpec.Template.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedWorkerResources, + &workerPodResources, + ) + } + }) + } } func TestDefaultStartParameters(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - rayJob := &plugins.RayJob{ - RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: true, - }, - ShutdownAfterJobFinishes: true, - TtlSecondsAfterFinished: 120, - } - - taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) - toleration := []corev1.Toleration{{ - Key: "storage", - Value: "dedicated", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }} - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) - assert.Nil(t, err) - - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) - assert.Nil(t, err) - - assert.NotNil(t, RayResource) - ray, ok := RayResource.(*rayv1.RayJob) - assert.True(t, ok) - - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) - assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) - assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) - - assert.Equal(t, 
ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, - map[string]string{ - "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", - "node-ip-address": "$MY_POD_IP", - }) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) - - workerReplica := int32(3) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) + rayJobResourceHandler := rayJobResourceHandler{} + rayJob := &plugins.RayJob{ + RayCluster: &plugins.RayCluster{ + HeadGroupSpec: &plugins.HeadGroupSpec{}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + }, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, + } + + taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) + toleration := []corev1.Toleration{{ + Key: "storage", + Value: "dedicated", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }} + err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) + assert.Nil(t, err) + + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) + assert.Nil(t, err) + + assert.NotNil(t, RayResource) + ray, ok := RayResource.(*rayv1.RayJob) + assert.True(t, ok) + + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) + assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) + + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, + map[string]string{ + "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", + "node-ip-address": "$MY_POD_IP", + }) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) + + workerReplica := int32(3) + assert.Equal(t, 
*ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } func TestInjectLogsSidecar(t *testing.T) { - rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj()) - params := []struct { - name string - taskTemplate core.TaskTemplate - // primaryContainerName string - logsSidecarCfg *corev1.Container - expectedVolumes []corev1.Volume - expectedPrimaryContainerVolumeMounts []corev1.VolumeMount - expectedLogsSidecarVolumeMounts []corev1.VolumeMount - }{ - { - "container target", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - }, - }, - Custom: rayJobObj, - }, - &corev1.Container{ - Name: "logs-sidecar", - Image: "test-image", - }, - []corev1.Volume{ - { - Name: "system-ray-state", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - ReadOnly: true, - }, - }, - }, - { - "container target with no sidecar", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - }, - }, - Custom: rayJobObj, - }, - nil, - nil, - nil, - nil, - }, - { - "pod target", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "main", - Image: "primary-image", - }, - }, - }), - Custom: rayJobObj, - Config: map[string]string{ - flytek8s.PrimaryContainerKey: "main", - }, - }, - &corev1.Container{ - Name: "logs-sidecar", - Image: "test-image", - }, - []corev1.Volume{ - { - Name: "system-ray-state", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - ReadOnly: true, - }, - }, - }, - { - "pod target with existing ray state volume", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "main", - Image: "primary-image", - VolumeMounts: []corev1.VolumeMount{ - { - Name: "test-vol", - MountPath: "/tmp/ray", - }, - }, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "test-vol", - VolumeSource: corev1.VolumeSource{ - EmptyDir: 
&corev1.EmptyDirVolumeSource{}, - }, - }, - }, - }), - Custom: rayJobObj, - Config: map[string]string{ - flytek8s.PrimaryContainerKey: "main", - }, - }, - &corev1.Container{ - Name: "logs-sidecar", - Image: "test-image", - }, - []corev1.Volume{ - { - Name: "test-vol", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - }, - []corev1.VolumeMount{ - { - Name: "test-vol", - MountPath: "/tmp/ray", - }, - }, - []corev1.VolumeMount{ - { - Name: "test-vol", - MountPath: "/tmp/ray", - ReadOnly: true, - }, - }, - }, - } - - for _, p := range params { - t.Run(p.name, func(t *testing.T) { - assert.NoError(t, SetConfig(&Config{ - LogsSidecar: p.logsSidecarCfg, - })) - taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - - // Check volumes - assert.EqualValues(t, p.expectedVolumes, headPodSpec.Volumes) - - // Check containers and respective volume mounts - foundPrimaryContainer := false - foundLogsSidecar := false - for _, cnt := range headPodSpec.Containers { - if cnt.Name == "ray-head" { - foundPrimaryContainer = true - assert.EqualValues( - t, - p.expectedPrimaryContainerVolumeMounts, - cnt.VolumeMounts, - ) - } - if p.logsSidecarCfg != nil && cnt.Name == p.logsSidecarCfg.Name { - foundLogsSidecar = true - assert.EqualValues( - t, - p.expectedLogsSidecarVolumeMounts, - cnt.VolumeMounts, - ) - } - } - assert.Equal(t, true, foundPrimaryContainer) - assert.Equal(t, p.logsSidecarCfg != nil, foundLogsSidecar) - }) - } + rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj()) + params := []struct { + name string + taskTemplate core.TaskTemplate + // primaryContainerName string + logsSidecarCfg *corev1.Container + expectedVolumes []corev1.Volume + expectedPrimaryContainerVolumeMounts []corev1.VolumeMount + expectedLogsSidecarVolumeMounts []corev1.VolumeMount + }{ + { + "container target", + core.TaskTemplate{ + Id: &core.Identifier{Name: "ray-id"}, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + }, + }, + Custom: rayJobObj, + }, + &corev1.Container{ + Name: "logs-sidecar", + Image: "test-image", + }, + []corev1.Volume{ + { + Name: "system-ray-state", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, + []corev1.VolumeMount{ + { + Name: "system-ray-state", + MountPath: "/tmp/ray", + }, + }, + []corev1.VolumeMount{ + { + Name: "system-ray-state", + MountPath: "/tmp/ray", + ReadOnly: true, + }, + }, + }, + { + "container target with no sidecar", + core.TaskTemplate{ + Id: &core.Identifier{Name: "ray-id"}, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + }, + }, + Custom: rayJobObj, + }, + nil, + nil, + nil, + nil, + }, + { + "pod target", + core.TaskTemplate{ + Id: &core.Identifier{Name: "ray-id"}, + Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "main", + Image: "primary-image", + }, + }, + }), + Custom: rayJobObj, + Config: map[string]string{ + flytek8s.PrimaryContainerKey: "main", + }, + }, + &corev1.Container{ + Name: "logs-sidecar", + Image: "test-image", + }, + []corev1.Volume{ + { + Name: 
"system-ray-state", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, + []corev1.VolumeMount{ + { + Name: "system-ray-state", + MountPath: "/tmp/ray", + }, + }, + []corev1.VolumeMount{ + { + Name: "system-ray-state", + MountPath: "/tmp/ray", + ReadOnly: true, + }, + }, + }, + { + "pod target with existing ray state volume", + core.TaskTemplate{ + Id: &core.Identifier{Name: "ray-id"}, + Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "main", + Image: "primary-image", + VolumeMounts: []corev1.VolumeMount{ + { + Name: "test-vol", + MountPath: "/tmp/ray", + }, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "test-vol", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, + }), + Custom: rayJobObj, + Config: map[string]string{ + flytek8s.PrimaryContainerKey: "main", + }, + }, + &corev1.Container{ + Name: "logs-sidecar", + Image: "test-image", + }, + []corev1.Volume{ + { + Name: "test-vol", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, + []corev1.VolumeMount{ + { + Name: "test-vol", + MountPath: "/tmp/ray", + }, + }, + []corev1.VolumeMount{ + { + Name: "test-vol", + MountPath: "/tmp/ray", + ReadOnly: true, + }, + }, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{ + LogsSidecar: p.logsSidecarCfg, + })) + taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + + // Check volumes + assert.EqualValues(t, p.expectedVolumes, headPodSpec.Volumes) + + // Check containers and respective volume mounts + foundPrimaryContainer := false + foundLogsSidecar := false + for _, cnt := range headPodSpec.Containers { + if cnt.Name == "ray-head" { + foundPrimaryContainer = true + assert.EqualValues( + t, + p.expectedPrimaryContainerVolumeMounts, + cnt.VolumeMounts, + ) + } + if p.logsSidecarCfg != nil && cnt.Name == p.logsSidecarCfg.Name { + foundLogsSidecar = true + assert.EqualValues( + t, + p.expectedLogsSidecarVolumeMounts, + cnt.VolumeMounts, + ) + } + } + assert.Equal(t, true, foundPrimaryContainer) + assert.Equal(t, p.logsSidecarCfg != nil, foundLogsSidecar) + }) + } } func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext { - plg := &mocks2.PluginContext{} - - taskExecID := &mocks.TaskExecutionID{} - taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 1, - }) - taskExecID.OnGetUniqueNodeID().Return("unique-node") - taskExecID.OnGetGeneratedName().Return("generated-name") - - tskCtx := &mocks.TaskExecutionMetadata{} - tskCtx.OnGetTaskExecutionID().Return(taskExecID) - plg.OnTaskExecutionMetadata().Return(tskCtx) - - pluginStateReaderMock := mocks.PluginStateReader{} - 
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( - func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = pluginState - return 0 - }, - func(v interface{}) error { - return nil - }) - - plg.OnPluginStateReader().Return(&pluginStateReaderMock) - - return plg + plg := &mocks2.PluginContext{} + + taskExecID := &mocks.TaskExecutionID{} + taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + taskExecID.OnGetUniqueNodeID().Return("unique-node") + taskExecID.OnGetGeneratedName().Return("generated-name") + + tskCtx := &mocks.TaskExecutionMetadata{} + tskCtx.OnGetTaskExecutionID().Return(taskExecID) + plg.OnTaskExecutionMetadata().Return(tskCtx) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + plg.OnPluginStateReader().Return(&pluginStateReaderMock) + + return plg } func init() { - f := defaultConfig - f.Logs = logs.LogConfig{ - IsKubernetesEnabled: true, - } - - if err := SetConfig(&f); err != nil { - panic(err) - } + f := defaultConfig + f.Logs = logs.LogConfig{ + IsKubernetesEnabled: true, + } + + if err := SetConfig(&f); err != nil { + panic(err) + } } func TestGetTaskPhase(t *testing.T) { - ctx := context.Background() - rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext(k8s.PluginState{}) - - testCases := []struct { - rayJobPhase rayv1.JobDeploymentStatus - expectedCorePhase pluginsCore.Phase - expectedError bool - }{ - {rayv1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, - {rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, - {rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, - {rayv1.JobDeploymentStatusFailed, pluginsCore.PhasePermanentFailure, false}, - {rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseUndefined, true}, - } - - for _, tc := range testCases { - t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) { - rayObject := &rayv1.RayJob{} - rayObject.Status.JobDeploymentStatus = tc.rayJobPhase - startTime := metav1.NewTime(time.Now()) - rayObject.Status.StartTime = &startTime - phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) - if tc.expectedError { - assert.Error(t, err) - } else { - assert.Nil(t, err) - } - assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) - }) - } + ctx := context.Background() + rayJobResourceHandler := rayJobResourceHandler{} + pluginCtx := newPluginContext(k8s.PluginState{}) + + testCases := []struct { + rayJobPhase rayv1.JobDeploymentStatus + expectedCorePhase pluginsCore.Phase + expectedError bool + }{ + {rayv1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, + {rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, + {rayv1.JobDeploymentStatusFailed, pluginsCore.PhasePermanentFailure, false}, + 
{rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseUndefined, true}, + } + + for _, tc := range testCases { + t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) { + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = tc.rayJobPhase + startTime := metav1.NewTime(time.Now()) + rayObject.Status.StartTime = &startTime + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } + assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) + }) + } } func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} + rayJobResourceHandler := rayJobResourceHandler{} - ctx := context.TODO() + ctx := context.TODO() - pluginState := k8s.PluginState{ - Phase: pluginsCore.PhaseInitializing, - PhaseVersion: pluginsCore.DefaultPhaseVersion, - Reason: "task submitted to K8s", - } - pluginCtx := newPluginContext(pluginState) + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + pluginCtx := newPluginContext(pluginState) - rayObject := &rayv1.RayJob{} - rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing - phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) - assert.NoError(t, err) - assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) } func TestGetEventInfo_LogTemplates(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - logPlugin tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "namespace", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "namespace", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "namespace", - Uri: "http://test/test-namespace", - }, - }, - }, - { - name: "task execution ID", - rayJob: rayv1.RayJob{}, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "taskExecID", - TemplateURIs: []tasklog.TemplateURI{ - "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", - }, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "taskExecID", - Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", - }, - }, - }, - { - name: "ray cluster name", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - RayClusterName: "ray-cluster", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray cluster name", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray cluster name", - Uri: "http://test/test-namespace/ray-cluster", - }, - }, - }, - { - name: 
"ray job ID", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - JobId: "ray-job-1", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray job ID", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray job ID", - Uri: "http://test/test-namespace/ray-job-1", - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ti, err := getEventInfoForRayJob( - logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, - pluginCtx, - &tc.rayJob, - ) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + logPlugin tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "namespace", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "namespace", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "namespace", + Uri: "http://test/test-namespace", + }, + }, + }, + { + name: "task execution ID", + rayJob: rayv1.RayJob{}, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "taskExecID", + TemplateURIs: []tasklog.TemplateURI{ + "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", + }, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "taskExecID", + Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", + }, + }, + }, + { + name: "ray cluster name", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + RayClusterName: "ray-cluster", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray cluster name", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray cluster name", + Uri: "http://test/test-namespace/ray-cluster", + }, + }, + }, + { + name: "ray job ID", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + JobId: "ray-job-1", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray job ID", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray job ID", + Uri: "http://test/test-namespace/ray-job-1", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti, err := getEventInfoForRayJob( + logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, + pluginCtx, + &tc.rayJob, + ) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_LogTemplates_V1(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - logPlugin tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "namespace", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", 
- }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "namespace", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "namespace", - Uri: "http://test/test-namespace", - }, - }, - }, - { - name: "task execution ID", - rayJob: rayv1.RayJob{}, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "taskExecID", - TemplateURIs: []tasklog.TemplateURI{ - "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", - }, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "taskExecID", - Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", - }, - }, - }, - { - name: "ray cluster name", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - RayClusterName: "ray-cluster", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray cluster name", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray cluster name", - Uri: "http://test/test-namespace/ray-cluster", - }, - }, - }, - { - name: "ray job ID", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - JobId: "ray-job-1", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray job ID", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray job ID", - Uri: "http://test/test-namespace/ray-job-1", - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ti, err := getEventInfoForRayJob( - logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, - pluginCtx, - &tc.rayJob, - ) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + logPlugin tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "namespace", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "namespace", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "namespace", + Uri: "http://test/test-namespace", + }, + }, + }, + { + name: "task execution ID", + rayJob: rayv1.RayJob{}, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "taskExecID", + TemplateURIs: []tasklog.TemplateURI{ + "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", + }, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "taskExecID", + Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", + }, + }, + }, + { + name: "ray cluster name", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + RayClusterName: "ray-cluster", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + 
DisplayName: "ray cluster name", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray cluster name", + Uri: "http://test/test-namespace/ray-cluster", + }, + }, + }, + { + name: "ray job ID", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + JobId: "ray-job-1", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray job ID", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray job ID", + Uri: "http://test/test-namespace/ray-job-1", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti, err := getEventInfoForRayJob( + logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, + pluginCtx, + &tc.rayJob, + ) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_DashboardURL(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - dashboardURLTemplate tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "dashboard URL displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - DashboardURL: "exists", - JobStatus: rayv1.JobStatusRunning, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "Ray Dashboard", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "Ray Dashboard", - Uri: "http://test/generated-name", - }, - }, - }, - { - name: "dashboard URL is not displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - JobStatus: rayv1.JobStatusPending, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "dummy", - TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, - }, - expectedTaskLogs: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: "dashboard URL is not displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + JobStatus: rayv1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, 
&tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_DashboardURL_V1(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - dashboardURLTemplate tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "dashboard URL displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - DashboardURL: "exists", - JobStatus: rayv1.JobStatusRunning, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "Ray Dashboard", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "Ray Dashboard", - Uri: "http://test/generated-name", - }, - }, - }, - { - name: "dashboard URL is not displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - JobStatus: rayv1.JobStatusPending, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "dummy", - TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, - }, - expectedTaskLogs: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: "dashboard URL is not displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + JobStatus: rayv1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetPropertiesRay(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - expected := k8s.PluginProperties{} - assert.Equal(t, expected, rayJobResourceHandler.GetProperties()) + rayJobResourceHandler := rayJobResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, rayJobResourceHandler.GetProperties()) }
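
The RayStartParams assertions above (dashboard-host, include-dashboard, node-ip-address, and disable-usage-stats filled in alongside the user-supplied num-cpus) imply a set-if-absent defaulting rule. A minimal, self-contained sketch of that rule, assuming a hypothetical setIfAbsent helper rather than the plugin's actual code:

package main

import "fmt"

// setIfAbsent fills in a start parameter only when the user has not already
// supplied it, which is the behavior TestBuildResourceRay asserts.
func setIfAbsent(params map[string]string, key, value string) {
	if _, ok := params[key]; !ok {
		params[key] = value
	}
}

func main() {
	// "num-cpus" stands in for a user-supplied parameter; the rest are defaults.
	params := map[string]string{"num-cpus": "1"}
	setIfAbsent(params, "include-dashboard", "true")
	setIfAbsent(params, "dashboard-host", "0.0.0.0")
	setIfAbsent(params, "node-ip-address", "$MY_POD_IP")
	setIfAbsent(params, "disable-usage-stats", "true")
	// Prints the same five entries the head-group assertion expects.
	fmt.Println(params)
}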
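
newPluginContext wires PluginStateReader.Get with a computed Return: the mock writes the stored k8s.PluginState back through the pointer the caller passes in. The same out-parameter-through-interface{} idiom, reduced to the standard library (decodeInto is a stand-in for illustration, not the flyte API):

package main

import "fmt"

// decodeInto mirrors the Get contract the mock implements: the caller hands
// over a pointer, and the stored value is written through it.
func decodeInto(stored string, v interface{}) error {
	p, ok := v.(*string)
	if !ok {
		return fmt.Errorf("expected *string, got %T", v)
	}
	*p = stored
	return nil
}

func main() {
	var state string
	if err := decodeInto("PhaseInitializing", &state); err != nil {
		panic(err)
	}
	fmt.Println(state) // PhaseInitializing
}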
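
The log-template and dashboard-URL tests all reduce to Go template substitution over task metadata: {{ .namespace }}, {{ .rayJobID }}, and {{ .generatedName }} are replaced with values taken from the RayJob object and the task execution ID. A stdlib-only illustration of that substitution, assuming nothing about the tasklog implementation beyond the placeholder names asserted above:

package main

import (
	"os"
	"text/template"
)

func main() {
	// Values mirroring the RayJob fixture used in the "ray job ID" case.
	vars := map[string]string{
		"namespace": "test-namespace",
		"rayJobID":  "ray-job-1",
	}
	tmpl := template.Must(template.New("log").Parse("http://test/{{ .namespace }}/{{ .rayJobID }}"))
	// Writes: http://test/test-namespace/ray-job-1
	if err := tmpl.Execute(os.Stdout, vars); err != nil {
		panic(err)
	}
}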