
Leverage KubeRay v1 instead of v1alpha1 for resources #4818

Merged · 14 commits · Feb 13, 2024
2 changes: 2 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/config.go
@@ -22,6 +22,7 @@ var (
IncludeDashboard: true,
DashboardHost: "0.0.0.0",
EnableUsageStats: false,
KubeRayCrdVersion: "v1alpha1",
Defaults: DefaultConfig{
HeadNode: NodeConfig{
StartParameters: map[string]string{
@@ -85,6 +86,7 @@ type Config struct {
DashboardURLTemplate *tasklog.TemplateLogPlugin `json:"dashboardURLTemplate" pflag:"-,Template for URL of Ray dashboard running on a head node."`
Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"`
EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
KubeRayCrdVersion string `json:"kubeRayCrdVersion" pflag:",Version of the Ray CRD to use when creating RayClusters or RayJobs."`
}

type DefaultConfig struct {
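The new kubeRayCrdVersion setting defaults to "v1alpha1", so existing deployments keep today's behavior unless they opt in. To switch the plugin to the v1 CRD, the value can be overridden in the ray section of the propeller plugin configuration; a minimal sketch, assuming the usual plugins.ray nesting (the exact layout depends on how your deployment wires plugin config):

plugins:
  ray:
    # Hypothetical override; only kubeRayCrdVersion is introduced by this change.
    kubeRayCrdVersion: v1
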
254 changes: 240 additions & 14 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -8,6 +8,7 @@
"strings"
"time"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -28,14 +29,15 @@
)

const (
rayStateMountPath = "/tmp/ray"
defaultRayStateVolName = "system-ray-state"
rayTaskType = "ray"
KindRayJob = "RayJob"
IncludeDashboard = "include-dashboard"
NodeIPAddress = "node-ip-address"
DashboardHost = "dashboard-host"
DisableUsageStatsStartParameter = "disable-usage-stats"
rayStateMountPath = "/tmp/ray"
defaultRayStateVolName = "system-ray-state"
rayTaskType = "ray"
KindRayJob = "RayJob"
IncludeDashboard = "include-dashboard"
NodeIPAddress = "node-ip-address"
DashboardHost = "dashboard-host"
DisableUsageStatsStartParameter = "disable-usage-stats"
DisableUsageStatsStartParameterVal = "true"
)

var logTemplateRegexes = struct {
@@ -52,7 +54,7 @@
return k8s.PluginProperties{}
}

// BuildResource Creates a new ray job resource.
// BuildResource Creates a new ray job resource for v1 or v1alpha1.
func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
@@ -109,11 +111,22 @@
}

if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

enableIngress := true
headPodSpec := podSpec.DeepCopy()

if cfg.KubeRayCrdVersion == "v1" {
return constructV1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil
}

return constructV1Alpha1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil

}

func constructV1Alpha1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1alpha1.RayJob {
enableIngress := true
cfg := GetConfig()
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(
@@ -152,7 +165,7 @@
}

if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
workerNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

minReplicas := spec.MinReplicas
@@ -198,16 +211,111 @@
RuntimeEnv: rayJob.RuntimeEnv,
}

rayJobObject := rayv1alpha1.RayJob{
return &rayv1alpha1.RayJob{
TypeMeta: metav1.TypeMeta{
Kind: KindRayJob,
APIVersion: rayv1alpha1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
ObjectMeta: *objectMeta,
}
}

func constructV1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1.RayJob {
enableIngress := true
cfg := GetConfig()
rayClusterSpec := rayv1.RayClusterSpec{
HeadGroupSpec: rayv1.HeadGroupSpec{
Template: buildHeadPodTemplate(
&headPodSpec.Containers[primaryContainerIdx],
headPodSpec,
objectMeta,
taskCtx,
),
ServiceType: v1.ServiceType(cfg.ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
RayStartParams: headNodeRayStartParams,
},
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{},
EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling,
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodSpec := podSpec.DeepCopy()
workerPodTemplate := buildWorkerPodTemplate(
&workerPodSpec.Containers[primaryContainerIdx],
workerPodSpec,
objectMeta,
taskCtx,
)

workerNodeRayStartParams := make(map[string]string)
if spec.RayStartParams != nil {
workerNodeRayStartParams = spec.RayStartParams
} else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 {
workerNodeRayStartParams = workerNode.StartParameters
}

if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist {
workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress
}

if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

return &rayJobObject, nil
minReplicas := spec.MinReplicas
if minReplicas > spec.Replicas {
minReplicas = spec.Replicas
}
maxReplicas := spec.MaxReplicas
if maxReplicas < spec.Replicas {
maxReplicas = spec.Replicas
}

workerNodeSpec := rayv1.WorkerGroupSpec{
GroupName: spec.GroupName,
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
Replicas: &spec.Replicas,
RayStartParams: workerNodeRayStartParams,
Template: workerPodTemplate,
}

rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec)

}

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
for index := range rayClusterSpec.WorkerGroupSpecs {
rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName
}

shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes
ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished
if rayJob.ShutdownAfterJobFinishes {
shutdownAfterJobFinishes = true
ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished
}

jobSpec := rayv1.RayJobSpec{
RayClusterSpec: &rayClusterSpec,
Entrypoint: strings.Join(primaryContainer.Args, " "),
ShutdownAfterJobFinishes: shutdownAfterJobFinishes,
TTLSecondsAfterFinished: ttlSecondsAfterFinished,
RuntimeEnv: rayJob.RuntimeEnv,
}

return &rayv1.RayJob{
TypeMeta: metav1.TypeMeta{
Kind: KindRayJob,
APIVersion: rayv1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
ObjectMeta: *objectMeta,
}

}
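
Both constructors build the same high-level RayJob shape (RayClusterSpec, Entrypoint, ShutdownAfterJobFinishes, TTLSecondsAfterFinished, RuntimeEnv); the practical difference on the emitted Kubernetes object is the CRD group/version it is stamped with. A rough sketch of the resulting manifests (all other fields elided), assuming KubeRay's ray.io API group:

# Emitted by constructV1Alpha1Job
apiVersion: ray.io/v1alpha1
kind: RayJob

# Emitted by constructV1Job
apiVersion: ray.io/v1
kind: RayJob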

func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) {
@@ -503,7 +611,125 @@
return &pluginsCore.TaskInfo{Logs: taskLogs}, nil
}

func getEventInfoForRayJobV1(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) {
logPlugin, err := logs.InitializeLogPlugins(&logConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err)
}

var taskLogs []*core.TaskLog

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
input := tasklog.Input{
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
ExtraTemplateVars: []tasklog.TemplateVar{},
}
if rayJob.Status.JobId != "" {
input.ExtraTemplateVars = append(
input.ExtraTemplateVars,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayJobID,
Value: rayJob.Status.JobId,
},
)
}
if rayJob.Status.RayClusterName != "" {
input.ExtraTemplateVars = append(
input.ExtraTemplateVars,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayClusterName,
Value: rayJob.Status.RayClusterName,
},
)
}

// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
// RayJob CRD does not include the name of the worker or head pod for now
logOutput, err := logPlugin.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate task logs. Error: %w", err)
}

taskLogs = append(taskLogs, logOutput.TaskLogs...)

// Handling for Ray Dashboard
dashboardURLTemplate := GetConfig().DashboardURLTemplate
if dashboardURLTemplate != nil &&
rayJob.Status.DashboardURL != "" &&
rayJob.Status.JobStatus == rayv1.JobStatusRunning {
dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err)
}

taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...)
}

return &pluginsCore.TaskInfo{Logs: taskLogs}, nil
}

func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
crdVersion := GetConfig().KubeRayCrdVersion
if crdVersion == "v1" {
return plugin.GetTaskPhaseV1(ctx, pluginContext, resource)
}

return plugin.GetTaskPhaseV1Alpha1(ctx, pluginContext, resource)
}

func (plugin rayJobResourceHandler) GetTaskPhaseV1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
rayJob := resource.(*rayv1.RayJob)
info, err := getEventInfoForRayJobV1(GetConfig().Logs, pluginContext, rayJob)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

if len(rayJob.Status.JobDeploymentStatus) == 0 {
return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil
}

// KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster
switch rayJob.Status.JobDeploymentStatus {
case rayv1.JobDeploymentStatusInitializing:
return pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil
case rayv1.JobDeploymentStatusFailedToGetOrCreateRayCluster:
reason := fmt.Sprintf("Failed to create Ray cluster %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1.JobDeploymentStatusFailedJobDeploy:
reason := fmt.Sprintf("Failed to submit Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
// JobDeploymentStatusSuspended is used when the suspend flag is set in rayJob. The suspend flag allows the temporary suspension of a Job's execution, which can be resumed later.
// Certain versions of KubeRay use a K8s job to submit a Ray job to the Ray cluster. JobDeploymentStatusWaitForK8sJob indicates that the K8s job is under creation.
case rayv1.JobDeploymentStatusWaitForDashboard, rayv1.JobDeploymentStatusFailedToGetJobStatus, rayv1.JobDeploymentStatusWaitForDashboardReady, rayv1.JobDeploymentStatusWaitForK8sJob, rayv1.JobDeploymentStatusSuspended:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1.JobDeploymentStatusRunning, rayv1.JobDeploymentStatusComplete:
switch rayJob.Status.JobStatus {
case rayv1.JobStatusFailed:
reason := fmt.Sprintf("Failed to run Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
// JobStatusStopped can occur when the suspend flag is set in rayJob.
case rayv1.JobStatusPending, rayv1.JobStatusStopped:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1.JobStatusRunning:
phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info)
if len(info.Logs) > 0 {
phaseInfo = phaseInfo.WithVersion(pluginsCore.DefaultPhaseVersion + 1)
}

return phaseInfo, nil
default:
// We already handle all known job statuses, so this should never happen unless a future version of Ray
// introduces a new job status.
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job status: %s", rayJob.Status.JobStatus)

}
default:
// We already handle all known deployment statuses, so this should never happen unless a future version of Ray
// introduces a new deployment status.
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus)

}
}

func (plugin rayJobResourceHandler) GetTaskPhaseV1Alpha1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
rayJob := resource.(*rayv1alpha1.RayJob)
info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob)
if err != nil {