diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts
index 10877bcb42e..0809147e506 100644
--- a/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts
+++ b/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts
@@ -5,6 +5,7 @@
 import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf";
 import { Message, proto3 } from "@bufbuild/protobuf";
+import { Resources } from "../core/tasks_pb.js";
 
 /**
  * RayJobSpec defines the desired state of RayJob
@@ -153,6 +154,13 @@ export class HeadGroupSpec extends Message<HeadGroupSpec> {
    */
   rayStartParams: { [key: string]: string } = {};
 
+  /**
+   * Resource specification for ray head pod
+   *
+   * @generated from field: flyteidl.core.Resources resources = 2;
+   */
+  resources?: Resources;
+
   constructor(data?: PartialMessage<HeadGroupSpec>) {
     super();
     proto3.util.initPartial(data, this);
@@ -162,6 +170,7 @@
   static readonly typeName = "flyteidl.plugins.HeadGroupSpec";
   static readonly fields: FieldList = proto3.util.newFieldList(() => [
     { no: 1, name: "ray_start_params", kind: "map", K: 9 /* ScalarType.STRING */, V: {kind: "scalar", T: 9 /* ScalarType.STRING */} },
+    { no: 2, name: "resources", kind: "message", T: Resources },
   ]);
 
   static fromBinary(bytes: Uint8Array, options?: Partial<BinaryReadOptions>): HeadGroupSpec {
@@ -223,6 +232,13 @@ export class WorkerGroupSpec extends Message<WorkerGroupSpec> {
    */
   rayStartParams: { [key: string]: string } = {};
 
+  /**
+   * Resource specification for ray worker pods
+   *
+   * @generated from field: flyteidl.core.Resources resources = 6;
+   */
+  resources?: Resources;
+
   constructor(data?: PartialMessage<WorkerGroupSpec>) {
     super();
     proto3.util.initPartial(data, this);
@@ -236,6 +252,7 @@
     { no: 3, name: "min_replicas", kind: "scalar", T: 5 /* ScalarType.INT32 */ },
     { no: 4, name: "max_replicas", kind: "scalar", T: 5 /* ScalarType.INT32 */ },
     { no: 5, name: "ray_start_params", kind: "map", K: 9 /* ScalarType.STRING */, V: {kind: "scalar", T: 9 /* ScalarType.STRING */} },
+    { no: 6, name: "resources", kind: "message", T: Resources },
   ]);
 
   static fromBinary(bytes: Uint8Array, options?: Partial<BinaryReadOptions>): WorkerGroupSpec {
diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go
index d0f072db51d..153eefecbb8 100644
--- a/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go
+++ b/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go
@@ -7,6 +7,7 @@
 package plugins
 
 import (
+	core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
 	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
 	reflect "reflect"
@@ -186,6 +187,8 @@ type HeadGroupSpec struct {
 	// Optional. RayStartParams are the params of the start command: address, object-store-memory.
// Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start RayStartParams map[string]string `protobuf:"bytes,1,rep,name=ray_start_params,json=rayStartParams,proto3" json:"ray_start_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Resource specification for ray head pod + Resources *core.Resources `protobuf:"bytes,2,opt,name=resources,proto3" json:"resources,omitempty"` } func (x *HeadGroupSpec) Reset() { @@ -227,6 +230,13 @@ func (x *HeadGroupSpec) GetRayStartParams() map[string]string { return nil } +func (x *HeadGroupSpec) GetResources() *core.Resources { + if x != nil { + return x.Resources + } + return nil +} + // WorkerGroupSpec are the specs for the worker pods type WorkerGroupSpec struct { state protoimpl.MessageState @@ -244,6 +254,8 @@ type WorkerGroupSpec struct { // Optional. RayStartParams are the params of the start command: address, object-store-memory. // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start RayStartParams map[string]string `protobuf:"bytes,5,rep,name=ray_start_params,json=rayStartParams,proto3" json:"ray_start_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Resource specification for ray worker pods + Resources *core.Resources `protobuf:"bytes,6,opt,name=resources,proto3" json:"resources,omitempty"` } func (x *WorkerGroupSpec) Reset() { @@ -313,87 +325,102 @@ func (x *WorkerGroupSpec) GetRayStartParams() map[string]string { return nil } +func (x *WorkerGroupSpec) GetResources() *core.Resources { + if x != nil { + return x.Resources + } + return nil +} + var File_flyteidl_plugins_ray_proto protoreflect.FileDescriptor var file_flyteidl_plugins_ray_proto_rawDesc = []byte{ 0x0a, 0x1a, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x10, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x22, 0x92, - 0x02, 0x0a, 0x06, 0x52, 0x61, 0x79, 0x4a, 0x6f, 0x62, 0x12, 0x3d, 0x0a, 0x0b, 0x72, 0x61, 0x79, - 0x5f, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, - 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, - 0x73, 0x2e, 0x52, 0x61, 0x79, 0x43, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x52, 0x0a, 0x72, 0x61, - 0x79, 0x43, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0b, 0x72, 0x75, 0x6e, 0x74, - 0x69, 0x6d, 0x65, 0x5f, 0x65, 0x6e, 0x76, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, - 0x01, 0x52, 0x0a, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x45, 0x6e, 0x76, 0x12, 0x3d, 0x0a, - 0x1b, 0x73, 0x68, 0x75, 0x74, 0x64, 0x6f, 0x77, 0x6e, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, - 0x6a, 0x6f, 0x62, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x18, 0x73, 0x68, 0x75, 0x74, 0x64, 0x6f, 0x77, 0x6e, 0x41, 0x66, 0x74, 0x65, - 0x72, 0x4a, 0x6f, 0x62, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x73, 0x12, 0x3b, 0x0a, 0x1a, - 0x74, 0x74, 0x6c, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x5f, 0x61, 0x66, 0x74, 0x65, - 0x72, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x17, 0x74, 0x74, 0x6c, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x41, 0x66, 0x74, 0x65, - 0x72, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x72, 0x75, 
0x6e, - 0x74, 0x69, 0x6d, 0x65, 0x5f, 0x65, 0x6e, 0x76, 0x5f, 0x79, 0x61, 0x6d, 0x6c, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x45, 0x6e, 0x76, 0x59, - 0x61, 0x6d, 0x6c, 0x22, 0xd3, 0x01, 0x0a, 0x0a, 0x52, 0x61, 0x79, 0x43, 0x6c, 0x75, 0x73, 0x74, - 0x65, 0x72, 0x12, 0x47, 0x0a, 0x0f, 0x68, 0x65, 0x61, 0x64, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, - 0x5f, 0x73, 0x70, 0x65, 0x63, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x48, - 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x52, 0x0d, 0x68, 0x65, - 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x4d, 0x0a, 0x11, 0x77, - 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x73, 0x70, 0x65, 0x63, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, - 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x52, 0x0f, 0x77, 0x6f, 0x72, 0x6b, 0x65, - 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x2d, 0x0a, 0x12, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x73, 0x63, 0x61, 0x6c, 0x69, 0x6e, 0x67, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, - 0x74, 0x6f, 0x73, 0x63, 0x61, 0x6c, 0x69, 0x6e, 0x67, 0x22, 0xb1, 0x01, 0x0a, 0x0d, 0x48, 0x65, - 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x5d, 0x0a, 0x10, 0x72, - 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, - 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, - 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x41, 0x0a, 0x13, 0x52, 0x61, - 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, - 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xb6, 0x02, - 0x0a, 0x0f, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, - 0x63, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x21, 0x0a, 0x0c, - 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x69, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x6d, 0x61, 0x78, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x73, 0x12, 
0x5f, 0x0a, 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, - 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x35, 0x2e, 0x66, - 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, - 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x2e, - 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, - 0x61, 0x6d, 0x73, 0x1a, 0x41, 0x0a, 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, - 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0xc0, 0x01, 0x0a, 0x14, 0x63, 0x6f, 0x6d, 0x2e, 0x66, - 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x42, - 0x08, 0x52, 0x61, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, - 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, - 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x58, - 0xaa, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, - 0x69, 0x6e, 0x73, 0xca, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, - 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xe2, 0x02, 0x1c, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x11, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x1a, 0x19, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x61, + 0x73, 0x6b, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x92, 0x02, 0x0a, 0x06, 0x52, 0x61, + 0x79, 0x4a, 0x6f, 0x62, 0x12, 0x3d, 0x0a, 0x0b, 0x72, 0x61, 0x79, 0x5f, 0x63, 0x6c, 0x75, 0x73, + 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x52, 0x61, 0x79, + 0x43, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x52, 0x0a, 0x72, 0x61, 0x79, 0x43, 0x6c, 0x75, 0x73, + 0x74, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0b, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x5f, 0x65, + 0x6e, 0x76, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x72, 0x75, + 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x45, 0x6e, 0x76, 0x12, 0x3d, 0x0a, 0x1b, 0x73, 0x68, 0x75, 0x74, + 0x64, 0x6f, 0x77, 0x6e, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, 0x6a, 0x6f, 0x62, 0x5f, 0x66, + 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x18, 0x73, + 0x68, 0x75, 0x74, 0x64, 0x6f, 0x77, 0x6e, 0x41, 0x66, 0x74, 0x65, 0x72, 0x4a, 0x6f, 0x62, 0x46, + 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 
0x73, 0x12, 0x3b, 0x0a, 0x1a, 0x74, 0x74, 0x6c, 0x5f, 0x73, + 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, 0x66, 0x69, 0x6e, + 0x69, 0x73, 0x68, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x17, 0x74, 0x74, 0x6c, + 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x41, 0x66, 0x74, 0x65, 0x72, 0x46, 0x69, 0x6e, 0x69, + 0x73, 0x68, 0x65, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x5f, + 0x65, 0x6e, 0x76, 0x5f, 0x79, 0x61, 0x6d, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, + 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x45, 0x6e, 0x76, 0x59, 0x61, 0x6d, 0x6c, 0x22, 0xd3, + 0x01, 0x0a, 0x0a, 0x52, 0x61, 0x79, 0x43, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x12, 0x47, 0x0a, + 0x0f, 0x68, 0x65, 0x61, 0x64, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x73, 0x70, 0x65, 0x63, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x52, 0x0d, 0x68, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x4d, 0x0a, 0x11, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, + 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x21, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, + 0x53, 0x70, 0x65, 0x63, 0x52, 0x0f, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x2d, 0x0a, 0x12, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x5f, + 0x61, 0x75, 0x74, 0x6f, 0x73, 0x63, 0x61, 0x6c, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x11, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x73, 0x63, 0x61, + 0x6c, 0x69, 0x6e, 0x67, 0x22, 0xe9, 0x01, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x5d, 0x0a, 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x33, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x73, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, + 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, + 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x1a, 0x41, 0x0a, + 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, + 0x22, 0xee, 0x02, 0x0a, 0x0f, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, + 0x53, 0x70, 0x65, 0x63, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x72, 
0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, + 0x21, 0x0a, 0x0c, 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x69, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, + 0x61, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x61, 0x78, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, + 0x61, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x52, 0x65, 0x70, + 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x5f, 0x0a, 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x35, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, + 0x6e, 0x73, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, + 0x65, 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, + 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x1a, 0x41, + 0x0a, 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, + 0x01, 0x42, 0xc0, 0x01, 0x0a, 0x14, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, + 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x42, 0x08, 0x52, 0x61, 0x79, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, + 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x58, 0xaa, 0x02, 0x10, 0x46, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xca, 0x02, + 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x73, 0xe2, 0x02, 0x1c, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0xea, 0x02, 0x11, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -416,18 +443,21 @@ var file_flyteidl_plugins_ray_proto_goTypes = []interface{}{ (*WorkerGroupSpec)(nil), // 3: flyteidl.plugins.WorkerGroupSpec nil, // 4: 
flyteidl.plugins.HeadGroupSpec.RayStartParamsEntry nil, // 5: flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntry + (*core.Resources)(nil), // 6: flyteidl.core.Resources } var file_flyteidl_plugins_ray_proto_depIdxs = []int32{ 1, // 0: flyteidl.plugins.RayJob.ray_cluster:type_name -> flyteidl.plugins.RayCluster 2, // 1: flyteidl.plugins.RayCluster.head_group_spec:type_name -> flyteidl.plugins.HeadGroupSpec 3, // 2: flyteidl.plugins.RayCluster.worker_group_spec:type_name -> flyteidl.plugins.WorkerGroupSpec 4, // 3: flyteidl.plugins.HeadGroupSpec.ray_start_params:type_name -> flyteidl.plugins.HeadGroupSpec.RayStartParamsEntry - 5, // 4: flyteidl.plugins.WorkerGroupSpec.ray_start_params:type_name -> flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntry - 5, // [5:5] is the sub-list for method output_type - 5, // [5:5] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 6, // 4: flyteidl.plugins.HeadGroupSpec.resources:type_name -> flyteidl.core.Resources + 5, // 5: flyteidl.plugins.WorkerGroupSpec.ray_start_params:type_name -> flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntry + 6, // 6: flyteidl.plugins.WorkerGroupSpec.resources:type_name -> flyteidl.core.Resources + 7, // [7:7] is the sub-list for method output_type + 7, // [7:7] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_flyteidl_plugins_ray_proto_init() } diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py index af809dcfcd1..35cdb3dfdea 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py @@ -11,9 +11,10 @@ _sym_db = _symbol_database.Default() +from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/plugins/ray.proto\x12\x10\x66lyteidl.plugins\"\x92\x02\n\x06RayJob\x12=\n\x0bray_cluster\x18\x01 \x01(\x0b\x32\x1c.flyteidl.plugins.RayClusterR\nrayCluster\x12#\n\x0bruntime_env\x18\x02 \x01(\tB\x02\x18\x01R\nruntimeEnv\x12=\n\x1bshutdown_after_job_finishes\x18\x03 \x01(\x08R\x18shutdownAfterJobFinishes\x12;\n\x1attl_seconds_after_finished\x18\x04 \x01(\x05R\x17ttlSecondsAfterFinished\x12(\n\x10runtime_env_yaml\x18\x05 \x01(\tR\x0eruntimeEnvYaml\"\xd3\x01\n\nRayCluster\x12G\n\x0fhead_group_spec\x18\x01 \x01(\x0b\x32\x1f.flyteidl.plugins.HeadGroupSpecR\rheadGroupSpec\x12M\n\x11worker_group_spec\x18\x02 \x03(\x0b\x32!.flyteidl.plugins.WorkerGroupSpecR\x0fworkerGroupSpec\x12-\n\x12\x65nable_autoscaling\x18\x03 \x01(\x08R\x11\x65nableAutoscaling\"\xb1\x01\n\rHeadGroupSpec\x12]\n\x10ray_start_params\x18\x01 \x03(\x0b\x32\x33.flyteidl.plugins.HeadGroupSpec.RayStartParamsEntryR\x0erayStartParams\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xb6\x02\n\x0fWorkerGroupSpec\x12\x1d\n\ngroup_name\x18\x01 \x01(\tR\tgroupName\x12\x1a\n\x08replicas\x18\x02 \x01(\x05R\x08replicas\x12!\n\x0cmin_replicas\x18\x03 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12_\n\x10ray_start_params\x18\x05 
\x03(\x0b\x32\x35.flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntryR\x0erayStartParams\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc0\x01\n\x14\x63om.flyteidl.pluginsB\x08RayProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/plugins/ray.proto\x12\x10\x66lyteidl.plugins\x1a\x19\x66lyteidl/core/tasks.proto\"\x92\x02\n\x06RayJob\x12=\n\x0bray_cluster\x18\x01 \x01(\x0b\x32\x1c.flyteidl.plugins.RayClusterR\nrayCluster\x12#\n\x0bruntime_env\x18\x02 \x01(\tB\x02\x18\x01R\nruntimeEnv\x12=\n\x1bshutdown_after_job_finishes\x18\x03 \x01(\x08R\x18shutdownAfterJobFinishes\x12;\n\x1attl_seconds_after_finished\x18\x04 \x01(\x05R\x17ttlSecondsAfterFinished\x12(\n\x10runtime_env_yaml\x18\x05 \x01(\tR\x0eruntimeEnvYaml\"\xd3\x01\n\nRayCluster\x12G\n\x0fhead_group_spec\x18\x01 \x01(\x0b\x32\x1f.flyteidl.plugins.HeadGroupSpecR\rheadGroupSpec\x12M\n\x11worker_group_spec\x18\x02 \x03(\x0b\x32!.flyteidl.plugins.WorkerGroupSpecR\x0fworkerGroupSpec\x12-\n\x12\x65nable_autoscaling\x18\x03 \x01(\x08R\x11\x65nableAutoscaling\"\xe9\x01\n\rHeadGroupSpec\x12]\n\x10ray_start_params\x18\x01 \x03(\x0b\x32\x33.flyteidl.plugins.HeadGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12\x36\n\tresources\x18\x02 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xee\x02\n\x0fWorkerGroupSpec\x12\x1d\n\ngroup_name\x18\x01 \x01(\tR\tgroupName\x12\x1a\n\x08replicas\x18\x02 \x01(\x05R\x08replicas\x12!\n\x0cmin_replicas\x18\x03 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12_\n\x10ray_start_params\x18\x05 \x03(\x0b\x32\x35.flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12\x36\n\tresources\x18\x06 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc0\x01\n\x14\x63om.flyteidl.pluginsB\x08RayProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,16 +29,16 @@ _HEADGROUPSPEC_RAYSTARTPARAMSENTRY._serialized_options = b'8\001' _WORKERGROUPSPEC_RAYSTARTPARAMSENTRY._options = None _WORKERGROUPSPEC_RAYSTARTPARAMSENTRY._serialized_options = b'8\001' - _globals['_RAYJOB']._serialized_start=49 - _globals['_RAYJOB']._serialized_end=323 - _globals['_RAYCLUSTER']._serialized_start=326 - _globals['_RAYCLUSTER']._serialized_end=537 - _globals['_HEADGROUPSPEC']._serialized_start=540 - _globals['_HEADGROUPSPEC']._serialized_end=717 - _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=652 - _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=717 - _globals['_WORKERGROUPSPEC']._serialized_start=720 - _globals['_WORKERGROUPSPEC']._serialized_end=1030 - _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=652 - 
_globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=717 + _globals['_RAYJOB']._serialized_start=76 + _globals['_RAYJOB']._serialized_end=350 + _globals['_RAYCLUSTER']._serialized_start=353 + _globals['_RAYCLUSTER']._serialized_end=564 + _globals['_HEADGROUPSPEC']._serialized_start=567 + _globals['_HEADGROUPSPEC']._serialized_end=800 + _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=735 + _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=800 + _globals['_WORKERGROUPSPEC']._serialized_start=803 + _globals['_WORKERGROUPSPEC']._serialized_end=1169 + _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=735 + _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=800 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi index 81d7382063a..3725d8e5d46 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi @@ -1,3 +1,4 @@ +from flyteidl.core import tasks_pb2 as _tasks_pb2 from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -30,7 +31,7 @@ class RayCluster(_message.Message): def __init__(self, head_group_spec: _Optional[_Union[HeadGroupSpec, _Mapping]] = ..., worker_group_spec: _Optional[_Iterable[_Union[WorkerGroupSpec, _Mapping]]] = ..., enable_autoscaling: bool = ...) -> None: ... class HeadGroupSpec(_message.Message): - __slots__ = ["ray_start_params"] + __slots__ = ["ray_start_params", "resources"] class RayStartParamsEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] @@ -39,11 +40,13 @@ class HeadGroupSpec(_message.Message): value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... RAY_START_PARAMS_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] ray_start_params: _containers.ScalarMap[str, str] - def __init__(self, ray_start_params: _Optional[_Mapping[str, str]] = ...) -> None: ... + resources: _tasks_pb2.Resources + def __init__(self, ray_start_params: _Optional[_Mapping[str, str]] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ...) -> None: ... class WorkerGroupSpec(_message.Message): - __slots__ = ["group_name", "replicas", "min_replicas", "max_replicas", "ray_start_params"] + __slots__ = ["group_name", "replicas", "min_replicas", "max_replicas", "ray_start_params", "resources"] class RayStartParamsEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] @@ -56,9 +59,11 @@ class WorkerGroupSpec(_message.Message): MIN_REPLICAS_FIELD_NUMBER: _ClassVar[int] MAX_REPLICAS_FIELD_NUMBER: _ClassVar[int] RAY_START_PARAMS_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] group_name: str replicas: int min_replicas: int max_replicas: int ray_start_params: _containers.ScalarMap[str, str] - def __init__(self, group_name: _Optional[str] = ..., replicas: _Optional[int] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., ray_start_params: _Optional[_Mapping[str, str]] = ...) -> None: ... 
+    resources: _tasks_pb2.Resources
+    def __init__(self, group_name: _Optional[str] = ..., replicas: _Optional[int] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., ray_start_params: _Optional[_Mapping[str, str]] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ...) -> None: ...
diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs
index 0252c9d882d..ddfcba6e8bf 100644
--- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs
+++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs
@@ -255,6 +255,9 @@ pub struct HeadGroupSpec {
     /// Refer to <https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start>
     #[prost(map="string, string", tag="1")]
     pub ray_start_params: ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>,
+    /// Resource specification for ray head pod
+    #[prost(message, optional, tag="2")]
+    pub resources: ::core::option::Option<super::core::Resources>,
 }
 /// WorkerGroupSpec are the specs for the worker pods
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -276,6 +279,9 @@ pub struct WorkerGroupSpec {
     /// Refer to <https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start>
     #[prost(map="string, string", tag="5")]
     pub ray_start_params: ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>,
+    /// Resource specification for ray worker pods
+    #[prost(message, optional, tag="6")]
+    pub resources: ::core::option::Option<super::core::Resources>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, Copy, PartialEq, ::prost::Message)]
diff --git a/flyteidl/protos/flyteidl/plugins/ray.proto b/flyteidl/protos/flyteidl/plugins/ray.proto
index c20c6360e7f..749f16932d9 100644
--- a/flyteidl/protos/flyteidl/plugins/ray.proto
+++ b/flyteidl/protos/flyteidl/plugins/ray.proto
@@ -4,6 +4,8 @@ package flyteidl.plugins;
 
 option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins";
 
+import "flyteidl/core/tasks.proto";
+
 // RayJobSpec defines the desired state of RayJob
 message RayJob {
     // RayClusterSpec is the cluster template to run the job
@@ -35,6 +37,8 @@ message HeadGroupSpec {
     // Optional. RayStartParams are the params of the start command: address, object-store-memory.
     // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start
     map<string, string> ray_start_params = 1;
+    // Resource specification for ray head pod
+    core.Resources resources = 2;
 }
 
 // WorkerGroupSpec are the specs for the worker pods
@@ -50,4 +54,6 @@ message WorkerGroupSpec {
     // Optional. RayStartParams are the params of the start command: address, object-store-memory.
    // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start
     map<string, string> ray_start_params = 5;
+    // Resource specification for ray worker pods
+    core.Resources resources = 6;
 }
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
index 95a87f4efab..e8de072b90b 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -1,625 +1,654 @@
 package ray
 
 import (
-	"context"
-	"encoding/base64"
-	"encoding/json"
-	"fmt"
-	"regexp"
-	"strconv"
-	"strings"
-	"time"
-
-	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
-	"gopkg.in/yaml.v2"
-	v1 "k8s.io/api/core/v1"
-	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"k8s.io/client-go/kubernetes/scheme"
-	"sigs.k8s.io/controller-runtime/pkg/client"
-
-	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
-	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
-	flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
-	pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
-	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"regexp"
+	"strconv"
+	"strings"
+	"time"
+
+	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
+	"gopkg.in/yaml.v2"
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/client-go/kubernetes/scheme"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
+	flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
+	pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
+	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
 )
 
 const (
-	rayStateMountPath                  = "/tmp/ray"
-	defaultRayStateVolName             = "system-ray-state"
-	rayTaskType                        = "ray"
-	KindRayJob                         = "RayJob"
-	IncludeDashboard                   = "include-dashboard"
-	NodeIPAddress                      = "node-ip-address"
-	DashboardHost                      = "dashboard-host"
-	DisableUsageStatsStartParameter    = "disable-usage-stats"
-	DisableUsageStatsStartParameterVal = "true"
+	rayStateMountPath                  = "/tmp/ray"
+	defaultRayStateVolName             = "system-ray-state"
+	rayTaskType                        = "ray"
+	KindRayJob                         = "RayJob"
+	IncludeDashboard                   = "include-dashboard"
+	NodeIPAddress                      = "node-ip-address"
+	DashboardHost                      = "dashboard-host"
+	DisableUsageStatsStartParameter    = "disable-usage-stats"
+	DisableUsageStatsStartParameterVal = "true"
 )
 
 var logTemplateRegexes = struct {
-	RayClusterName *regexp.Regexp
-	RayJobID *regexp.Regexp
+
RayClusterName *regexp.Regexp + RayJobID *regexp.Regexp }{ - tasklog.MustCreateRegex("rayClusterName"), - tasklog.MustCreateRegex("rayJobID"), + tasklog.MustCreateRegex("rayClusterName"), + tasklog.MustCreateRegex("rayJobID"), } type rayJobResourceHandler struct{} func (rayJobResourceHandler) GetProperties() k8s.PluginProperties { - return k8s.PluginProperties{} + return k8s.PluginProperties{} } // BuildResource Creates a new ray job resource func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - taskTemplate, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) - } else if taskTemplate == nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") - } - - rayJob := plugins.RayJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) - } - - var primaryContainer *v1.Container - var primaryContainerIdx int - for idx, c := range podSpec.Containers { - if c.Name == primaryContainerName { - c := c - primaryContainer = &c - primaryContainerIdx = idx - break - } - } - - if primaryContainer == nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error()) - } - - cfg := GetConfig() - - headNodeRayStartParams := make(map[string]string) - if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil { - headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams - } else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 { - headNodeRayStartParams = headNode.StartParameters - } - - if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist { - headNodeRayStartParams[IncludeDashboard] = strconv.FormatBool(GetConfig().IncludeDashboard) - } - - if _, exist := headNodeRayStartParams[NodeIPAddress]; !exist { - headNodeRayStartParams[NodeIPAddress] = cfg.Defaults.HeadNode.IPAddress - } - - if _, exist := headNodeRayStartParams[DashboardHost]; !exist { - headNodeRayStartParams[DashboardHost] = cfg.DashboardHost - } - - if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { - headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal - } - - podSpec.ServiceAccountName = cfg.ServiceAccount - - headPodSpec := podSpec.DeepCopy() - - rayjob, err := constructRayJob(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) - - return rayjob, err + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) + } else if taskTemplate == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") + } + + rayJob := plugins.RayJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) + if err != nil { + return nil, 
flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + var primaryContainer *v1.Container + var primaryContainerIdx int + for idx, c := range podSpec.Containers { + if c.Name == primaryContainerName { + c := c + primaryContainer = &c + primaryContainerIdx = idx + break + } + } + + if primaryContainer == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error()) + } + + cfg := GetConfig() + + headNodeRayStartParams := make(map[string]string) + if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil { + headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams + } else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 { + headNodeRayStartParams = headNode.StartParameters + } + + if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist { + headNodeRayStartParams[IncludeDashboard] = strconv.FormatBool(GetConfig().IncludeDashboard) + } + + if _, exist := headNodeRayStartParams[NodeIPAddress]; !exist { + headNodeRayStartParams[NodeIPAddress] = cfg.Defaults.HeadNode.IPAddress + } + + if _, exist := headNodeRayStartParams[DashboardHost]; !exist { + headNodeRayStartParams[DashboardHost] = cfg.DashboardHost + } + + if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { + headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal + } + + podSpec.ServiceAccountName = cfg.ServiceAccount + + rayjob, err := constructRayJob(taskCtx, &rayJob, objectMeta, *podSpec, headNodeRayStartParams, primaryContainerIdx, *primaryContainer) + + return rayjob, err } -func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { - enableIngress := true - cfg := GetConfig() - rayClusterSpec := rayv1.RayClusterSpec{ - HeadGroupSpec: rayv1.HeadGroupSpec{ - Template: buildHeadPodTemplate( - &headPodSpec.Containers[primaryContainerIdx], - headPodSpec, - objectMeta, - taskCtx, - ), - ServiceType: v1.ServiceType(cfg.ServiceType), - EnableIngress: &enableIngress, - RayStartParams: headNodeRayStartParams, - }, - WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, - EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, - } - - for _, spec := range rayJob.RayCluster.WorkerGroupSpec { - workerPodSpec := podSpec.DeepCopy() - workerPodTemplate := buildWorkerPodTemplate( - &workerPodSpec.Containers[primaryContainerIdx], - workerPodSpec, - objectMeta, - taskCtx, - ) - - workerNodeRayStartParams := make(map[string]string) - if spec.RayStartParams != nil { - workerNodeRayStartParams = spec.RayStartParams - } else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 { - workerNodeRayStartParams = workerNode.StartParameters - } - - if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist { - workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress - } - - if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; 
!exists && !cfg.EnableUsageStats { - workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal - } - - minReplicas := spec.MinReplicas - if minReplicas > spec.Replicas { - minReplicas = spec.Replicas - } - maxReplicas := spec.MaxReplicas - if maxReplicas < spec.Replicas { - maxReplicas = spec.Replicas - } - - workerNodeSpec := rayv1.WorkerGroupSpec{ - GroupName: spec.GroupName, - MinReplicas: &minReplicas, - MaxReplicas: &maxReplicas, - Replicas: &spec.Replicas, - RayStartParams: workerNodeRayStartParams, - Template: workerPodTemplate, - } - - rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) - } - - serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - if len(serviceAccountName) == 0 { - serviceAccountName = cfg.ServiceAccount - } - - rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName - for index := range rayClusterSpec.WorkerGroupSpecs { - rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName - } - - shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes - ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished - if rayJob.ShutdownAfterJobFinishes { - shutdownAfterJobFinishes = true - ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished - } - - submitterPodTemplate := buildSubmitterPodTemplate(headPodSpec, objectMeta, taskCtx) - - // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. - var err error - var runtimeEnvYaml string - runtimeEnvYaml = rayJob.RuntimeEnvYaml - // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml - if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { - runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) - if err != nil { - return nil, err - } - } - - jobSpec := rayv1.RayJobSpec{ - RayClusterSpec: &rayClusterSpec, - Entrypoint: strings.Join(primaryContainer.Args, " "), - ShutdownAfterJobFinishes: shutdownAfterJobFinishes, - TTLSecondsAfterFinished: *ttlSecondsAfterFinished, - RuntimeEnvYAML: runtimeEnvYaml, - SubmitterPodTemplate: &submitterPodTemplate, - } - - return &rayv1.RayJob{ - TypeMeta: metav1.TypeMeta{ - Kind: KindRayJob, - APIVersion: rayv1.SchemeGroupVersion.String(), - }, - Spec: jobSpec, - ObjectMeta: *objectMeta, - }, nil +func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.RayJob, objectMeta *metav1.ObjectMeta, taskPodSpec v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) { + enableIngress := true + cfg := GetConfig() + + headPodSpec := taskPodSpec.DeepCopy() + headPodTemplate, err := buildHeadPodTemplate( + &headPodSpec.Containers[primaryContainerIdx], + headPodSpec, + objectMeta, + taskCtx, + rayJob.RayCluster.HeadGroupSpec, + ) + if err != nil { + return nil, err + } + + rayClusterSpec := rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: headPodTemplate, + ServiceType: v1.ServiceType(cfg.ServiceType), + EnableIngress: &enableIngress, + RayStartParams: headNodeRayStartParams, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, + EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, + } + + for _, spec := range rayJob.RayCluster.WorkerGroupSpec { + workerPodSpec := taskPodSpec.DeepCopy() + workerPodTemplate, err := buildWorkerPodTemplate( + 
&workerPodSpec.Containers[primaryContainerIdx], + workerPodSpec, + objectMeta, + taskCtx, + spec, + ) + if err != nil { + return nil, err + } + + workerNodeRayStartParams := make(map[string]string) + if spec.RayStartParams != nil { + workerNodeRayStartParams = spec.RayStartParams + } else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 { + workerNodeRayStartParams = workerNode.StartParameters + } + + if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist { + workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress + } + + if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { + workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal + } + + minReplicas := spec.MinReplicas + if minReplicas > spec.Replicas { + minReplicas = spec.Replicas + } + maxReplicas := spec.MaxReplicas + if maxReplicas < spec.Replicas { + maxReplicas = spec.Replicas + } + + workerNodeSpec := rayv1.WorkerGroupSpec{ + GroupName: spec.GroupName, + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + Replicas: &spec.Replicas, + RayStartParams: workerNodeRayStartParams, + Template: workerPodTemplate, + } + + rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) + } + + serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + if len(serviceAccountName) == 0 { + serviceAccountName = cfg.ServiceAccount + } + + rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName + for index := range rayClusterSpec.WorkerGroupSpecs { + rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName + } + + shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes + ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished + if rayJob.ShutdownAfterJobFinishes { + shutdownAfterJobFinishes = true + ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished + } + + submitterPodSpec := taskPodSpec.DeepCopy() + submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx) + + // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto. 
+ var runtimeEnvYaml string + runtimeEnvYaml = rayJob.RuntimeEnvYaml + // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml + if rayJob.RuntimeEnv != "" && rayJob.RuntimeEnvYaml == "" { + runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.RuntimeEnv) + if err != nil { + return nil, err + } + } + + jobSpec := rayv1.RayJobSpec{ + RayClusterSpec: &rayClusterSpec, + Entrypoint: strings.Join(primaryContainer.Args, " "), + ShutdownAfterJobFinishes: shutdownAfterJobFinishes, + TTLSecondsAfterFinished: *ttlSecondsAfterFinished, + RuntimeEnvYAML: runtimeEnvYaml, + SubmitterPodTemplate: &submitterPodTemplate, + } + + return &rayv1.RayJob{ + TypeMeta: metav1.TypeMeta{ + Kind: KindRayJob, + APIVersion: rayv1.SchemeGroupVersion.String(), + }, + Spec: jobSpec, + ObjectMeta: *objectMeta, + }, nil } func convertBase64RuntimeEnvToYaml(s string) (string, error) { - // Decode from base64 - data, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return "", err - } - - // Unmarshal JSON - var obj map[string]interface{} - err = json.Unmarshal(data, &obj) - if err != nil { - return "", err - } - - // Convert to YAML - y, err := yaml.Marshal(&obj) - if err != nil { - return "", err - } - - return string(y), nil + // Decode from base64 + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return "", err + } + + // Unmarshal JSON + var obj map[string]interface{} + err = json.Unmarshal(data, &obj) + if err != nil { + return "", err + } + + // Convert to YAML + y, err := yaml.Marshal(&obj) + if err != nil { + return "", err + } + + return string(y), nil } func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) { - cfg := GetConfig() - if cfg.LogsSidecar == nil { - return - } - sidecar := cfg.LogsSidecar.DeepCopy() - - // Ray logs integration - var rayStateVolMount *v1.VolumeMount - // Look for an existing volume mount on the primary container, mounted at /tmp/ray - for _, vm := range primaryContainer.VolumeMounts { - if vm.MountPath == rayStateMountPath { - vm := vm - rayStateVolMount = &vm - break - } - } - // No existing volume mount exists at /tmp/ray. We create a new volume and volume - // mount and add it to the pod and container specs respectively - if rayStateVolMount == nil { - vol := v1.Volume{ - Name: defaultRayStateVolName, - VolumeSource: v1.VolumeSource{ - EmptyDir: &v1.EmptyDirVolumeSource{}, - }, - } - podSpec.Volumes = append(podSpec.Volumes, vol) - volMount := v1.VolumeMount{ - Name: defaultRayStateVolName, - MountPath: rayStateMountPath, - } - primaryContainer.VolumeMounts = append(primaryContainer.VolumeMounts, volMount) - rayStateVolMount = &volMount - } - // We need to mirror the ray state volume mount into the sidecar as readonly, - // so that we can read the logs written by the head node. - readOnlyRayStateVolMount := *rayStateVolMount.DeepCopy() - readOnlyRayStateVolMount.ReadOnly = true - - // Update volume mounts on sidecar - // If one already exists with the desired mount path, simply replace it. Otherwise, - // add it to sidecar's volume mounts. 
- foundExistingSidecarVolMount := false - for idx, vm := range sidecar.VolumeMounts { - if vm.MountPath == rayStateMountPath { - foundExistingSidecarVolMount = true - sidecar.VolumeMounts[idx] = readOnlyRayStateVolMount - } - } - if !foundExistingSidecarVolMount { - sidecar.VolumeMounts = append(sidecar.VolumeMounts, readOnlyRayStateVolMount) - } - - // Add sidecar to containers - podSpec.Containers = append(podSpec.Containers, *sidecar) + cfg := GetConfig() + if cfg.LogsSidecar == nil { + return + } + sidecar := cfg.LogsSidecar.DeepCopy() + + // Ray logs integration + var rayStateVolMount *v1.VolumeMount + // Look for an existing volume mount on the primary container, mounted at /tmp/ray + for _, vm := range primaryContainer.VolumeMounts { + if vm.MountPath == rayStateMountPath { + vm := vm + rayStateVolMount = &vm + break + } + } + // No existing volume mount exists at /tmp/ray. We create a new volume and volume + // mount and add it to the pod and container specs respectively + if rayStateVolMount == nil { + vol := v1.Volume{ + Name: defaultRayStateVolName, + VolumeSource: v1.VolumeSource{ + EmptyDir: &v1.EmptyDirVolumeSource{}, + }, + } + podSpec.Volumes = append(podSpec.Volumes, vol) + volMount := v1.VolumeMount{ + Name: defaultRayStateVolName, + MountPath: rayStateMountPath, + } + primaryContainer.VolumeMounts = append(primaryContainer.VolumeMounts, volMount) + rayStateVolMount = &volMount + } + // We need to mirror the ray state volume mount into the sidecar as readonly, + // so that we can read the logs written by the head node. + readOnlyRayStateVolMount := *rayStateVolMount.DeepCopy() + readOnlyRayStateVolMount.ReadOnly = true + + // Update volume mounts on sidecar + // If one already exists with the desired mount path, simply replace it. Otherwise, + // add it to sidecar's volume mounts. + foundExistingSidecarVolMount := false + for idx, vm := range sidecar.VolumeMounts { + if vm.MountPath == rayStateMountPath { + foundExistingSidecarVolMount = true + sidecar.VolumeMounts[idx] = readOnlyRayStateVolMount + } + } + if !foundExistingSidecarVolMount { + sidecar.VolumeMounts = append(sidecar.VolumeMounts, readOnlyRayStateVolMount) + } + + // Add sidecar to containers + podSpec.Containers = append(podSpec.Containers, *sidecar) } -func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { - // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 - // They should always be the same, so we could hard code here. - primaryContainer.Name = "ray-head" - - envs := []v1.EnvVar{ - { - Name: "MY_POD_IP", - ValueFrom: &v1.EnvVarSource{ - FieldRef: &v1.ObjectFieldSelector{ - FieldPath: "status.podIP", - }, - }, - }, - } - - primaryContainer.Args = []string{} - - primaryContainer.Env = append(primaryContainer.Env, envs...) - - ports := []v1.ContainerPort{ - { - Name: "redis", - ContainerPort: 6379, - }, - { - Name: "head", - ContainerPort: 10001, - }, - { - Name: "dashboard", - ContainerPort: 8265, - }, - } - - primaryContainer.Ports = append(primaryContainer.Ports, ports...) 
- - // Inject a sidecar for capturing and exposing Ray job logs - injectLogsSidecar(primaryContainer, podSpec) - - podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, - ObjectMeta: *objectMeta, - } - cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) - return podTemplateSpec +func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.HeadGroupSpec) (v1.PodTemplateSpec, error) { + // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 + // They should always be the same, so we could hard code here. + primaryContainer.Name = "ray-head" + + envs := []v1.EnvVar{ + { + Name: "MY_POD_IP", + ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + } + + primaryContainer.Args = []string{} + + primaryContainer.Env = append(primaryContainer.Env, envs...) + + ports := []v1.ContainerPort{ + { + Name: "redis", + ContainerPort: 6379, + }, + { + Name: "head", + ContainerPort: 10001, + }, + { + Name: "dashboard", + ContainerPort: 8265, + }, + } + + primaryContainer.Ports = append(primaryContainer.Ports, ports...) + + // Inject a sidecar for capturing and exposing Ray job logs + injectLogsSidecar(primaryContainer, podSpec) + + // Overwrite head pod taskResources if specified + if spec.Resources != nil { + res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) + if err != nil { + return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification HeadGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) + } + + primaryContainer.Resources = *res + } + + podTemplateSpec := v1.PodTemplateSpec{ + Spec: *podSpec, + ObjectMeta: *objectMeta, + } + cfg := config.GetK8sPluginConfig() + podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + return podTemplateSpec, nil } func buildSubmitterPodTemplate(podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { - podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, - ObjectMeta: *objectMeta, - } - cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) - return podTemplateSpec + submitterPodSpec := podSpec.DeepCopy() + + podTemplateSpec := v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *submitterPodSpec, + } + + cfg := config.GetK8sPluginConfig() + podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + 
podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
+ return podTemplateSpec
}

-func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
- // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
- // They should always be the same, so we could hard code here.
-
- primaryContainer.Name = "ray-worker"
-
- primaryContainer.Args = []string{}
-
- envs := []v1.EnvVar{
- {
- Name: "RAY_DISABLE_DOCKER_CPU_WARNING",
- Value: "1",
- },
- {
- Name: "TYPE",
- Value: "worker",
- },
- {
- Name: "CPU_REQUEST",
- ValueFrom: &v1.EnvVarSource{
- ResourceFieldRef: &v1.ResourceFieldSelector{
- ContainerName: "ray-worker",
- Resource: "requests.cpu",
- },
- },
- },
- {
- Name: "CPU_LIMITS",
- ValueFrom: &v1.EnvVarSource{
- ResourceFieldRef: &v1.ResourceFieldSelector{
- ContainerName: "ray-worker",
- Resource: "limits.cpu",
- },
- },
- },
- {
- Name: "MEMORY_REQUESTS",
- ValueFrom: &v1.EnvVarSource{
- ResourceFieldRef: &v1.ResourceFieldSelector{
- ContainerName: "ray-worker",
- Resource: "requests.cpu",
- },
- },
- },
- {
- Name: "MEMORY_LIMITS",
- ValueFrom: &v1.EnvVarSource{
- ResourceFieldRef: &v1.ResourceFieldSelector{
- ContainerName: "ray-worker",
- Resource: "limits.cpu",
- },
- },
- },
- {
- Name: "MY_POD_NAME",
- ValueFrom: &v1.EnvVarSource{
- FieldRef: &v1.ObjectFieldSelector{
- FieldPath: "metadata.name",
- },
- },
- },
- {
- Name: "MY_POD_IP",
- ValueFrom: &v1.EnvVarSource{
- FieldRef: &v1.ObjectFieldSelector{
- FieldPath: "status.podIP",
- },
- },
- },
- }
-
- primaryContainer.Env = append(primaryContainer.Env, envs...)
-
- primaryContainer.Lifecycle = &v1.Lifecycle{
- PreStop: &v1.LifecycleHandler{
- Exec: &v1.ExecAction{
- Command: []string{
- "/bin/sh", "-c", "ray stop",
- },
- },
- },
- }
-
- ports := []v1.ContainerPort{
- {
- Name: "redis",
- ContainerPort: 6379,
- },
- {
- Name: "head",
- ContainerPort: 10001,
- },
- {
- Name: "dashboard",
- ContainerPort: 8265,
- },
- }
- primaryContainer.Ports = append(primaryContainer.Ports, ports...)
-
- podTemplateSpec := v1.PodTemplateSpec{
- Spec: *podSpec,
- ObjectMeta: *objectMetadata,
- }
- podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
- podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
- return podTemplateSpec
+func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.WorkerGroupSpec) (v1.PodTemplateSpec, error) {
+ // Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
+ // They should always be the same, so we can hard-code them here.
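+ // Most of the env vars below are wired through the Kubernetes downward API
+ // (fieldRef/resourceFieldRef): they expose the pod's name, IP, and the
+ // container's resource requests and limits to the Ray worker at runtime.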
+
+ primaryContainer.Name = "ray-worker"
+
+ primaryContainer.Args = []string{}
+
+ envs := []v1.EnvVar{
+ {
+ Name: "RAY_DISABLE_DOCKER_CPU_WARNING",
+ Value: "1",
+ },
+ {
+ Name: "TYPE",
+ Value: "worker",
+ },
+ {
+ Name: "CPU_REQUEST",
+ ValueFrom: &v1.EnvVarSource{
+ ResourceFieldRef: &v1.ResourceFieldSelector{
+ ContainerName: "ray-worker",
+ Resource: "requests.cpu",
+ },
+ },
+ },
+ {
+ Name: "CPU_LIMITS",
+ ValueFrom: &v1.EnvVarSource{
+ ResourceFieldRef: &v1.ResourceFieldSelector{
+ ContainerName: "ray-worker",
+ Resource: "limits.cpu",
+ },
+ },
+ },
+ {
+ Name: "MEMORY_REQUESTS",
+ ValueFrom: &v1.EnvVarSource{
+ ResourceFieldRef: &v1.ResourceFieldSelector{
+ ContainerName: "ray-worker",
+ Resource: "requests.memory",
+ },
+ },
+ },
+ {
+ Name: "MEMORY_LIMITS",
+ ValueFrom: &v1.EnvVarSource{
+ ResourceFieldRef: &v1.ResourceFieldSelector{
+ ContainerName: "ray-worker",
+ Resource: "limits.memory",
+ },
+ },
+ },
+ {
+ Name: "MY_POD_NAME",
+ ValueFrom: &v1.EnvVarSource{
+ FieldRef: &v1.ObjectFieldSelector{
+ FieldPath: "metadata.name",
+ },
+ },
+ },
+ {
+ Name: "MY_POD_IP",
+ ValueFrom: &v1.EnvVarSource{
+ FieldRef: &v1.ObjectFieldSelector{
+ FieldPath: "status.podIP",
+ },
+ },
+ },
+ }
+
+ primaryContainer.Env = append(primaryContainer.Env, envs...)
+
+ primaryContainer.Lifecycle = &v1.Lifecycle{
+ PreStop: &v1.LifecycleHandler{
+ Exec: &v1.ExecAction{
+ Command: []string{
+ "/bin/sh", "-c", "ray stop",
+ },
+ },
+ },
+ }
+
+ ports := []v1.ContainerPort{
+ {
+ Name: "redis",
+ ContainerPort: 6379,
+ },
+ {
+ Name: "head",
+ ContainerPort: 10001,
+ },
+ {
+ Name: "dashboard",
+ ContainerPort: 8265,
+ },
+ }
+ primaryContainer.Ports = append(primaryContainer.Ports, ports...)
+
+ // Overwrite the worker pod resources if specified
+ if spec.Resources != nil {
+ res, err := flytek8s.ToK8sResourceRequirements(spec.Resources)
+ if err != nil {
+ return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification WorkerGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error())
+ }
+
+ primaryContainer.Resources = *res
+ }
+
+ podTemplateSpec := v1.PodTemplateSpec{
+ Spec: *podSpec,
+ ObjectMeta: *objectMetadata,
+ }
+ podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
+ podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
+ return podTemplateSpec, nil
}

func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) {
- return &rayv1.RayJob{
- TypeMeta: metav1.TypeMeta{
- Kind: KindRayJob,
- APIVersion: rayv1.SchemeGroupVersion.String(),
- },
- }, nil
+ return &rayv1.RayJob{
+ TypeMeta: metav1.TypeMeta{
+ Kind: KindRayJob,
+ APIVersion: rayv1.SchemeGroupVersion.String(),
+ },
+ }, nil
}

func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) {
- logPlugin, err := logs.InitializeLogPlugins(&logConfig)
- if err != nil {
- return nil, fmt.Errorf("failed to initialize log plugins. 
Error: %w", err) - } - - var taskLogs []*core.TaskLog - - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - input := tasklog.Input{ - Namespace: rayJob.Namespace, - TaskExecutionID: taskExecID, - ExtraTemplateVars: []tasklog.TemplateVar{}, - } - if rayJob.Status.JobId != "" { - input.ExtraTemplateVars = append( - input.ExtraTemplateVars, - tasklog.TemplateVar{ - Regex: logTemplateRegexes.RayJobID, - Value: rayJob.Status.JobId, - }, - ) - } - if rayJob.Status.RayClusterName != "" { - input.ExtraTemplateVars = append( - input.ExtraTemplateVars, - tasklog.TemplateVar{ - Regex: logTemplateRegexes.RayClusterName, - Value: rayJob.Status.RayClusterName, - }, - ) - } - - // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs - // RayJob CRD does not include the name of the worker or head pod for now - logOutput, err := logPlugin.GetTaskLogs(input) - if err != nil { - return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) - } - taskLogs = append(taskLogs, logOutput.TaskLogs...) - - // Handling for Ray Dashboard - dashboardURLTemplate := GetConfig().DashboardURLTemplate - if dashboardURLTemplate != nil && - rayJob.Status.DashboardURL != "" && - rayJob.Status.JobStatus == rayv1.JobStatusRunning { - dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) - if err != nil { - return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) - } - taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) - } - - return &pluginsCore.TaskInfo{Logs: taskLogs}, nil + logPlugin, err := logs.InitializeLogPlugins(&logConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err) + } + + var taskLogs []*core.TaskLog + + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() + input := tasklog.Input{ + Namespace: rayJob.Namespace, + TaskExecutionID: taskExecID, + ExtraTemplateVars: []tasklog.TemplateVar{}, + } + if rayJob.Status.JobId != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayJobID, + Value: rayJob.Status.JobId, + }, + ) + } + if rayJob.Status.RayClusterName != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayClusterName, + Value: rayJob.Status.RayClusterName, + }, + ) + } + + // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs + // RayJob CRD does not include the name of the worker or head pod for now + logOutput, err := logPlugin.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) + } + taskLogs = append(taskLogs, logOutput.TaskLogs...) + + // Handling for Ray Dashboard + dashboardURLTemplate := GetConfig().DashboardURLTemplate + if dashboardURLTemplate != nil && + rayJob.Status.DashboardURL != "" && + rayJob.Status.JobStatus == rayv1.JobStatusRunning { + dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) + } + taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) 
+ } + + return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { - rayJob := resource.(*rayv1.RayJob) - info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - - if len(rayJob.Status.JobDeploymentStatus) == 0 { - return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", info), nil - } - - var phaseInfo pluginsCore.PhaseInfo - - // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster - switch rayJob.Status.JobDeploymentStatus { - case rayv1.JobDeploymentStatusInitializing: - phaseInfo, err = pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil - case rayv1.JobDeploymentStatusRunning: - phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil - case rayv1.JobDeploymentStatusComplete: - phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil - case rayv1.JobDeploymentStatusSuspended: - phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspended", info), nil - case rayv1.JobDeploymentStatusSuspending: - phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspending", info), nil - case rayv1.JobDeploymentStatusFailed: - failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) - phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil - default: - // We already handle all known deployment status, so this should never happen unless a future version of ray - // introduced a new job status. 
- // introduced a new job status.
- phaseInfo, err = pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus)
- }
-
- phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext)
- if phaseVersionUpdateErr != nil {
- return phaseInfo, phaseVersionUpdateErr
- }
-
- return phaseInfo, err
+ rayJob := resource.(*rayv1.RayJob)
+ info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob)
+ if err != nil {
+ return pluginsCore.PhaseInfoUndefined, err
+ }
+
+ if len(rayJob.Status.JobDeploymentStatus) == 0 {
+ return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", info), nil
+ }
+
+ var phaseInfo pluginsCore.PhaseInfo
+
+ // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster
+ switch rayJob.Status.JobDeploymentStatus {
+ case rayv1.JobDeploymentStatusInitializing:
+ phaseInfo, err = pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil
+ case rayv1.JobDeploymentStatusRunning:
+ phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
+ case rayv1.JobDeploymentStatusComplete:
+ phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil
+ case rayv1.JobDeploymentStatusSuspended:
+ phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspended", info), nil
+ case rayv1.JobDeploymentStatusSuspending:
+ phaseInfo, err = pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Suspending", info), nil
+ case rayv1.JobDeploymentStatusFailed:
+ failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message)
+ phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil
+ default:
+ // We already handle all known deployment statuses, so this should never happen unless a future version of Ray
+ // introduces a new job status.
+ phaseInfo, err = pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus)
+ }
+
+ phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext)
+ if phaseVersionUpdateErr != nil {
+ return phaseInfo, phaseVersionUpdateErr
+ }
+
+ return phaseInfo, err
}

func init() {
- if err := rayv1.AddToScheme(scheme.Scheme); err != nil {
- panic(err)
- }
-
- pluginmachinery.PluginRegistry().RegisterK8sPlugin(
- k8s.PluginEntry{
- ID: rayTaskType,
- RegisteredTaskTypes: []pluginsCore.TaskType{rayTaskType},
- ResourceToWatch: &rayv1.RayJob{},
- Plugin: rayJobResourceHandler{},
- IsDefault: false,
- CustomKubeClient: func(ctx context.Context) (pluginsCore.KubeClient, error) {
- remoteConfig := GetConfig().RemoteClusterConfig
- if !remoteConfig.Enabled {
- // use controller-runtime KubeClient
- return nil, nil
- }
-
- kubeConfig, err := k8s.KubeClientConfig(remoteConfig.Endpoint, remoteConfig.Auth)
- if err != nil {
- return nil, err
- }
-
- return k8s.NewDefaultKubeClient(kubeConfig)
- },
- })
+ if err := rayv1.AddToScheme(scheme.Scheme); err != nil {
+ panic(err)
+ }
+
+ pluginmachinery.PluginRegistry().RegisterK8sPlugin(
+ k8s.PluginEntry{
+ ID: rayTaskType,
+ RegisteredTaskTypes: []pluginsCore.TaskType{rayTaskType},
+ ResourceToWatch: &rayv1.RayJob{},
+ Plugin: rayJobResourceHandler{},
+ IsDefault: false,
+ CustomKubeClient: func(ctx context.Context) (pluginsCore.KubeClient, error) {
+ remoteConfig := GetConfig().RemoteClusterConfig
+ if !remoteConfig.Enabled {
+ // use controller-runtime KubeClient
+ return nil, nil
+ }
+
+ kubeConfig, err := k8s.KubeClientConfig(remoteConfig.Endpoint, remoteConfig.Auth)
+ if err != nil {
+ return nil, err
+ }
+
+ return k8s.NewDefaultKubeClient(kubeConfig)
+ },
+ })
}
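Taken together, the plugin changes above plumb the new optional `resources` field on HeadGroupSpec and WorkerGroupSpec into the head and worker pod templates. For reference, a minimal sketch of the proto a caller could now construct; the types mirror the generated code in this diff, the helper name buildRayJobWithResources is illustrative only, and task registration/serialization are elided:

func buildRayJobWithResources() *plugins.RayJob {
	// Per-group resource overrides, expressed with the generated core.Resources type.
	headResources := &core.Resources{
		Requests: []*core.Resources_ResourceEntry{
			{Name: core.Resources_CPU, Value: "10"},
			{Name: core.Resources_MEMORY, Value: "10Gi"},
		},
	}
	return &plugins.RayJob{
		RayCluster: &plugins.RayCluster{
			// When Resources is nil, buildHeadPodTemplate/buildWorkerPodTemplate leave
			// the task-level container resources untouched; when set, the primary
			// container's resources are overwritten via flytek8s.ToK8sResourceRequirements.
			HeadGroupSpec: &plugins.HeadGroupSpec{
				RayStartParams: map[string]string{"num-cpus": "1"},
				Resources:      headResources,
			},
			WorkerGroupSpec: []*plugins.WorkerGroupSpec{{
				GroupName:   "worker-group",
				Replicas:    3,
				MinReplicas: 3,
				MaxReplicas: 3,
				Resources:   headResources,
			}},
		},
	}
}

The submitter pod, by contrast, keeps the task-level resources: buildSubmitterPodTemplate deep-copies the incoming pod spec and applies no per-group override. The tests that follow cover both behaviors.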
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 38b2f56785f..d2e35fb67d7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -1,1105 +1,1201 @@ package ray import ( - "context" - "reflect" - "testing" - "time" - - structpb "github.com/golang/protobuf/ptypes/struct" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "context" + "reflect" + "testing" + "time" + + structpb "github.com/golang/protobuf/ptypes/struct" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const ( - testImage = "image://" - serviceAccount = "ray_sa" + testImage = "image://" + serviceAccount = "ray_sa" ) var ( - dummyEnvVars = []*core.KeyValuePair{ - {Key: "Env_Var", Value: "Env_Val"}, - } - - testArgs = []string{ - "test-args", - } - - resourceRequirements = &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1000m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - Requests: corev1.ResourceList{ - corev1.ResourceCPU: 
resource.MustParse("100m"), - corev1.ResourceMemory: resource.MustParse("512Mi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - } - - workerGroupName = "worker-group" + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "test-args", + } + + resourceRequirements = &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + workerGroupName = "worker-group" ) func transformRayJobToCustomObj(rayJob *plugins.RayJob) *structpb.Struct { - structObj, err := utils.MarshalObjToStruct(rayJob) - if err != nil { - panic(err) - } - return structObj + structObj, err := utils.MarshalObjToStruct(rayJob) + if err != nil { + panic(err) + } + return structObj } func transformPodSpecToTaskTemplateTarget(podSpec *corev1.PodSpec) *core.TaskTemplate_K8SPod { - structObj, err := utils.MarshalObjToStruct(&podSpec) - if err != nil { - panic(err) - } - return &core.TaskTemplate_K8SPod{ - K8SPod: &core.K8SPod{ - PodSpec: structObj, - }, - } + structObj, err := utils.MarshalObjToStruct(&podSpec) + if err != nil { + panic(err) + } + return &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: structObj, + }, + } } func dummyRayCustomObj() *plugins.RayJob { - return &plugins.RayJob{ - RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: true, - }, - ShutdownAfterJobFinishes: true, - TtlSecondsAfterFinished: 120, - } + return &plugins.RayJob{ + RayCluster: &plugins.RayCluster{ + HeadGroupSpec: &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + }, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, + } } func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate { - return &core.TaskTemplate{ - Id: &core.Identifier{Name: id}, - Type: "container", - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - Env: dummyEnvVars, - }, - }, - Custom: transformRayJobToCustomObj(rayJob), - } + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: transformRayJobToCustomObj(rayJob), + } } func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage, serviceAccount string) pluginsCore.TaskExecutionContext { - taskCtx := &mocks.TaskExecutionContext{} - inputReader := &pluginIOMocks.InputReader{} - inputReader.OnGetInputPrefixPath().Return("/input/prefix") - inputReader.OnGetInputPath().Return("/input") - inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - taskCtx.OnInputReader().Return(inputReader) - - outputReader := 
&pluginIOMocks.OutputWriter{} - outputReader.OnGetOutputPath().Return("/data/outputs.pb") - outputReader.OnGetOutputPrefixPath().Return("/data/") - outputReader.OnGetRawOutputPrefix().Return("") - outputReader.OnGetCheckpointPrefix().Return("/checkpoint") - outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") - taskCtx.OnOutputWriter().Return(outputReader) - - taskReader := &mocks.TaskReader{} - taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) - taskCtx.OnTaskReader().Return(taskReader) - - tID := &mocks.TaskExecutionID{} - tID.OnGetID().Return(core.TaskExecutionIdentifier{ - NodeExecutionId: &core.NodeExecutionIdentifier{ - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my_name", - Project: "my_project", - Domain: "my_domain", - }, - }, - }) - tID.OnGetGeneratedName().Return("some-acceptable-name") - - overrides := &mocks.TaskOverrides{} - overrides.OnGetResources().Return(resources) - overrides.OnGetExtendedResources().Return(extendedResources) - overrides.OnGetContainerImage().Return(containerImage) - - taskExecutionMetadata := &mocks.TaskExecutionMetadata{} - taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) - taskExecutionMetadata.OnGetNamespace().Return("test-namespace") - taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) - taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) - taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ - Kind: "node", - Name: "blah", - }) - taskExecutionMetadata.OnIsInterruptible().Return(true) - taskExecutionMetadata.OnGetOverrides().Return(overrides) - taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) - taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) - taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{ - RunAs: &core.Identity{K8SServiceAccount: serviceAccount}, - }) - taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) - taskExecutionMetadata.OnGetConsoleURL().Return("") - taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - return taskCtx + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return("/input/prefix") + inputReader.OnGetInputPath().Return("/input") + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return("/data/outputs.pb") + outputReader.OnGetOutputPrefixPath().Return("/data/") + outputReader.OnGetRawOutputPrefix().Return("") + outputReader.OnGetCheckpointPrefix().Return("/checkpoint") + outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.OnGetGeneratedName().Return("some-acceptable-name") + + overrides := &mocks.TaskOverrides{} + overrides.OnGetResources().Return(resources) + overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return(containerImage) + + taskExecutionMetadata 
:= &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(overrides) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: serviceAccount}, + }) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) + taskExecutionMetadata.OnGetConsoleURL().Return("") + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx } func TestBuildResourceRay(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) - toleration := []corev1.Toleration{{ - Key: "storage", - Value: "dedicated", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }} - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) - assert.Nil(t, err) - - rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) - assert.Nil(t, err) - - assert.NotNil(t, RayResource) - ray, ok := RayResource.(*rayv1.RayJob) - assert.True(t, ok) - - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) - assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) - assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) - - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, - map[string]string{ - "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", - "node-ip-address": "$MY_POD_IP", "num-cpus": "1", - }) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) - - workerReplica := int32(3) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, 
map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) - - // Make sure the default service account is being used if SA is not provided in the task context - rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "") - RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) - assert.Nil(t, err) - assert.NotNil(t, RayResource) - ray, ok = RayResource.(*rayv1.RayJob) - assert.True(t, ok) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + toleration := []corev1.Toleration{{ + Key: "storage", + Value: "dedicated", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }} + err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) + assert.Nil(t, err) + + rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) + assert.Nil(t, err) + + assert.NotNil(t, RayResource) + ray, ok := RayResource.(*rayv1.RayJob) + assert.True(t, ok) + + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) + assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) + + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, + map[string]string{ + "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", + "node-ip-address": "$MY_POD_IP", "num-cpus": "1", + }) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) + + workerReplica := int32(3) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) + + // Make sure the default service account is being used if SA is not provided in the task context + rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "") + RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) + assert.Nil(t, err) + assert.NotNil(t, RayResource) + ray, ok = 
RayResource.(*rayv1.RayJob) + assert.True(t, ok) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount) } func TestBuildResourceRayContainerImage(t *testing.T) { - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) - - fixtures := []struct { - name string - resources *corev1.ResourceRequirements - containerImageOverride string - }{ - { - "without overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - "", - }, - { - "with overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - "container-image-override", - }, - } - - for _, f := range fixtures { - t.Run(f.name, func(t *testing.T) { - taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) - taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - var expectedContainerImage string - if len(f.containerImageOverride) > 0 { - expectedContainerImage = f.containerImageOverride - } else { - expectedContainerImage = testImage - } - - // Head node - headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) - - // Worker node - workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec - assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) - }) - } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *corev1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "container-image-override", + }, + } + + for _, f := range fixtures { + t.Run(f.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) + taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + var expectedContainerImage string + if len(f.containerImageOverride) > 0 { + expectedContainerImage = f.containerImageOverride + } else { + expectedContainerImage = testImage + } + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) + }) + } } func TestBuildResourceRayExtendedResources(t *testing.T) { - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: 
"gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, - })) - - params := []struct { - name string - resources *corev1.ResourceRequirements - extendedResourcesBase *core.ExtendedResources - extendedResourcesOverride *core.ExtendedResources - expectedNsr []corev1.NodeSelectorTerm - expectedTol []corev1.Toleration - }{ - { - "without overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-t4", - }, - }, - nil, - []corev1.NodeSelectorTerm{ - { - MatchExpressions: []corev1.NodeSelectorRequirement{ - { - Key: "gpu-node-label", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"nvidia-tesla-t4"}, - }, - }, - }, - }, - []corev1.Toleration{ - { - Key: "gpu-node-label", - Value: "nvidia-tesla-t4", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - { - "with overrides", - &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-t4", - }, - }, - &core.ExtendedResources{ - GpuAccelerator: &core.GPUAccelerator{ - Device: "nvidia-tesla-a100", - PartitionSizeValue: &core.GPUAccelerator_PartitionSize{ - PartitionSize: "1g.5gb", - }, - }, - }, - []corev1.NodeSelectorTerm{ - { - MatchExpressions: []corev1.NodeSelectorRequirement{ - { - Key: "gpu-node-label", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"nvidia-tesla-a100"}, - }, - { - Key: "gpu-partition-size", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"1g.5gb"}, - }, - }, - }, - }, - []corev1.Toleration{ - { - Key: "gpu-node-label", - Value: "nvidia-tesla-a100", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "gpu-partition-size", - Value: "1g.5gb", - Operator: corev1.TolerationOpEqual, - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - } - - for _, p := range params { - t.Run(p.name, func(t *testing.T) { - taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) - taskTemplate.ExtendedResources = p.extendedResourcesBase - taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount) - rayJobResourceHandler := rayJobResourceHandler{} - r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) - assert.Nil(t, err) - assert.NotNil(t, r) - rayJob, ok := r.(*rayv1.RayJob) - assert.True(t, ok) - - // Head node - headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - assert.EqualValues( - t, - p.expectedNsr, - headNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) - assert.EqualValues( - t, - p.expectedTol, - headNodeSpec.Tolerations, - ) - - // Worker node - workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec - assert.EqualValues( - t, - p.expectedNsr, - workerNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) - assert.EqualValues( - t, - p.expectedTol, - workerNodeSpec.Tolerations, - ) - }) - } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + })) + + params := []struct { + name string + resources 
*corev1.ResourceRequirements + extendedResourcesBase *core.ExtendedResources + extendedResourcesOverride *core.ExtendedResources + expectedNsr []corev1.NodeSelectorTerm + expectedTol []corev1.Toleration + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-t4", + }, + }, + nil, + []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "gpu-node-label", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"nvidia-tesla-t4"}, + }, + }, + }, + }, + []corev1.Toleration{ + { + Key: "gpu-node-label", + Value: "nvidia-tesla-t4", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-t4", + }, + }, + &core.ExtendedResources{ + GpuAccelerator: &core.GPUAccelerator{ + Device: "nvidia-tesla-a100", + PartitionSizeValue: &core.GPUAccelerator_PartitionSize{ + PartitionSize: "1g.5gb", + }, + }, + }, + []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "gpu-node-label", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"nvidia-tesla-a100"}, + }, + { + Key: "gpu-partition-size", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"1g.5gb"}, + }, + }, + }, + }, + []corev1.Toleration{ + { + Key: "gpu-node-label", + Value: "nvidia-tesla-a100", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + { + Key: "gpu-partition-size", + Value: "1g.5gb", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + taskTemplate.ExtendedResources = p.extendedResourcesBase + taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.EqualValues( + t, + p.expectedNsr, + headNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, + ) + assert.EqualValues( + t, + p.expectedTol, + headNodeSpec.Tolerations, + ) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.EqualValues( + t, + p.expectedNsr, + workerNodeSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, + ) + assert.EqualValues( + t, + p.expectedTol, + workerNodeSpec.Tolerations, + ) + }) + } +} + +func TestBuildResourceRayCustomResources(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + headResourceEntries := []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "10"}, + {Name: core.Resources_MEMORY, Value: "10Gi"}, + {Name: core.Resources_GPU, Value: "10"}, + } + headResources := &core.Resources{Requests: headResourceEntries, Limits: 
headResourceEntries} + + expectedHeadResources, err := flytek8s.ToK8sResourceRequirements(headResources) + require.NoError(t, err) + + workerResourceEntries := []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "20"}, + {Name: core.Resources_MEMORY, Value: "20Gi"}, + {Name: core.Resources_GPU, Value: "20"}, + } + workerResources := &core.Resources{Requests: workerResourceEntries, Limits: workerResourceEntries} + + expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(workerResources) + require.NoError(t, err) + + params := []struct { + name string + taskResources *corev1.ResourceRequirements + headResources *core.Resources + workerResources *core.Resources + expectedSubmitterResources *corev1.ResourceRequirements + expectedHeadResources *corev1.ResourceRequirements + expectedWorkerResources *corev1.ResourceRequirements + }{ + { + name: "task resources", + taskResources: resourceRequirements, + expectedSubmitterResources: resourceRequirements, + expectedHeadResources: resourceRequirements, + expectedWorkerResources: resourceRequirements, + }, + { + name: "custom worker and head resources", + taskResources: resourceRequirements, + headResources: headResources, + workerResources: workerResources, + expectedSubmitterResources: resourceRequirements, + expectedHeadResources: expectedHeadResources, + expectedWorkerResources: expectedWorkerResources, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + rayJobInput := dummyRayCustomObj() + + if p.headResources != nil { + rayJobInput.RayCluster.HeadGroupSpec.Resources = p.headResources + } + + if p.workerResources != nil { + for _, spec := range rayJobInput.RayCluster.WorkerGroupSpec { + spec.Resources = p.workerResources + } + } + + taskTemplate := dummyRayTaskTemplate("ray-id", rayJobInput) + taskContext := dummyRayTaskContext(taskTemplate, p.taskResources, nil, "", serviceAccount) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1.RayJob) + assert.True(t, ok) + + submitterPodResources := rayJob.Spec.SubmitterPodTemplate.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedSubmitterResources, + &submitterPodResources, + ) + + headPodResources := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedHeadResources, + &headPodResources, + ) + + for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { + workerPodResources := workerGroupSpec.Template.Spec.Containers[0].Resources + assert.EqualValues(t, + p.expectedWorkerResources, + &workerPodResources, + ) + } + }) + } } func TestDefaultStartParameters(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - rayJob := &plugins.RayJob{ - RayCluster: &plugins.RayCluster{ - HeadGroupSpec: &plugins.HeadGroupSpec{}, - WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, - EnableAutoscaling: true, - }, - ShutdownAfterJobFinishes: true, - TtlSecondsAfterFinished: 120, - } - - taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) - toleration := []corev1.Toleration{{ - Key: "storage", - Value: "dedicated", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }} - err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) - assert.Nil(t, err) - - RayResource, err := 
rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) - assert.Nil(t, err) - - assert.NotNil(t, RayResource) - ray, ok := RayResource.(*rayv1.RayJob) - assert.True(t, ok) - - assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) - assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) - assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) - - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, - map[string]string{ - "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true", - "node-ip-address": "$MY_POD_IP", - }) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) - - workerReplica := int32(3) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) - assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) - assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) + rayJobResourceHandler := rayJobResourceHandler{} + rayJob := &plugins.RayJob{ + RayCluster: &plugins.RayCluster{ + HeadGroupSpec: &plugins.HeadGroupSpec{}, + WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, MinReplicas: 3, MaxReplicas: 3}}, + EnableAutoscaling: true, + }, + ShutdownAfterJobFinishes: true, + TtlSecondsAfterFinished: 120, + } + + taskTemplate := dummyRayTaskTemplate("ray-id", rayJob) + toleration := []corev1.Toleration{{ + Key: "storage", + Value: "dedicated", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }} + err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) + assert.Nil(t, err) + + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) + assert.Nil(t, err) + + assert.NotNil(t, RayResource) + ray, ok := RayResource.(*rayv1.RayJob) + assert.True(t, ok) + + assert.Equal(t, *ray.Spec.RayClusterSpec.EnableInTreeAutoscaling, true) + assert.Equal(t, ray.Spec.ShutdownAfterJobFinishes, true) + assert.Equal(t, ray.Spec.TTLSecondsAfterFinished, int32(120)) + + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, + map[string]string{ + "dashboard-host": "0.0.0.0", 
"disable-usage-stats": "true", "include-dashboard": "true", + "node-ip-address": "$MY_POD_IP", + }) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration) + + workerReplica := int32(3) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, workerReplica) + assert.Equal(t, *ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, workerReplica) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) + assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } func TestInjectLogsSidecar(t *testing.T) { - rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj()) - params := []struct { - name string - taskTemplate core.TaskTemplate - // primaryContainerName string - logsSidecarCfg *corev1.Container - expectedVolumes []corev1.Volume - expectedPrimaryContainerVolumeMounts []corev1.VolumeMount - expectedLogsSidecarVolumeMounts []corev1.VolumeMount - }{ - { - "container target", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - }, - }, - Custom: rayJobObj, - }, - &corev1.Container{ - Name: "logs-sidecar", - Image: "test-image", - }, - []corev1.Volume{ - { - Name: "system-ray-state", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - ReadOnly: true, - }, - }, - }, - { - "container target with no sidecar", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, - Args: testArgs, - }, - }, - Custom: rayJobObj, - }, - nil, - nil, - nil, - nil, - }, - { - "pod target", - core.TaskTemplate{ - Id: &core.Identifier{Name: "ray-id"}, - Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "main", - Image: "primary-image", - }, - }, - }), - Custom: rayJobObj, - Config: map[string]string{ - flytek8s.PrimaryContainerKey: "main", - }, - }, - &corev1.Container{ - Name: "logs-sidecar", - Image: "test-image", - }, - []corev1.Volume{ - { - Name: "system-ray-state", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - }, - }, - []corev1.VolumeMount{ - { - Name: "system-ray-state", - MountPath: "/tmp/ray", - ReadOnly: true, - }, - }, - }, - { - "pod target 
with existing ray state volume",
-			core.TaskTemplate{
-				Id: &core.Identifier{Name: "ray-id"},
-				Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{
-					Containers: []corev1.Container{
-						{
-							Name:  "main",
-							Image: "primary-image",
-							VolumeMounts: []corev1.VolumeMount{
-								{
-									Name:      "test-vol",
-									MountPath: "/tmp/ray",
-								},
-							},
-						},
-					},
-					Volumes: []corev1.Volume{
-						{
-							Name: "test-vol",
-							VolumeSource: corev1.VolumeSource{
-								EmptyDir: &corev1.EmptyDirVolumeSource{},
-							},
-						},
-					},
-				}),
-				Custom: rayJobObj,
-				Config: map[string]string{
-					flytek8s.PrimaryContainerKey: "main",
-				},
-			},
-			&corev1.Container{
-				Name:  "logs-sidecar",
-				Image: "test-image",
-			},
-			[]corev1.Volume{
-				{
-					Name: "test-vol",
-					VolumeSource: corev1.VolumeSource{
-						EmptyDir: &corev1.EmptyDirVolumeSource{},
-					},
-				},
-			},
-			[]corev1.VolumeMount{
-				{
-					Name:      "test-vol",
-					MountPath: "/tmp/ray",
-				},
-			},
-			[]corev1.VolumeMount{
-				{
-					Name:      "test-vol",
-					MountPath: "/tmp/ray",
-					ReadOnly:  true,
-				},
-			},
-		},
-	}
-
-	for _, p := range params {
-		t.Run(p.name, func(t *testing.T) {
-			assert.NoError(t, SetConfig(&Config{
-				LogsSidecar: p.logsSidecarCfg,
-			}))
-			taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount)
-			rayJobResourceHandler := rayJobResourceHandler{}
-			r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
-			assert.Nil(t, err)
-			assert.NotNil(t, r)
-			rayJob, ok := r.(*rayv1.RayJob)
-			assert.True(t, ok)
-
-			headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
-
-			// Check volumes
-			assert.EqualValues(t, p.expectedVolumes, headPodSpec.Volumes)
-
-			// Check containers and respective volume mounts
-			foundPrimaryContainer := false
-			foundLogsSidecar := false
-			for _, cnt := range headPodSpec.Containers {
-				if cnt.Name == "ray-head" {
-					foundPrimaryContainer = true
-					assert.EqualValues(
-						t,
-						p.expectedPrimaryContainerVolumeMounts,
-						cnt.VolumeMounts,
-					)
-				}
-				if p.logsSidecarCfg != nil && cnt.Name == p.logsSidecarCfg.Name {
-					foundLogsSidecar = true
-					assert.EqualValues(
-						t,
-						p.expectedLogsSidecarVolumeMounts,
-						cnt.VolumeMounts,
-					)
-				}
-			}
-			assert.Equal(t, true, foundPrimaryContainer)
-			assert.Equal(t, p.logsSidecarCfg != nil, foundLogsSidecar)
-		})
-	}
+	rayJobObj := transformRayJobToCustomObj(dummyRayCustomObj())
+	params := []struct {
+		name                                 string
+		taskTemplate                         core.TaskTemplate
+		logsSidecarCfg                       *corev1.Container
+		expectedVolumes                      []corev1.Volume
+		expectedPrimaryContainerVolumeMounts []corev1.VolumeMount
+		expectedLogsSidecarVolumeMounts      []corev1.VolumeMount
+	}{
+		{
+			"container target",
+			core.TaskTemplate{
+				Id: &core.Identifier{Name: "ray-id"},
+				Target: &core.TaskTemplate_Container{
+					Container: &core.Container{
+						Image: testImage,
+						Args:  testArgs,
+					},
+				},
+				Custom: rayJobObj,
+			},
+			&corev1.Container{
+				Name:  "logs-sidecar",
+				Image: "test-image",
+			},
+			[]corev1.Volume{
+				{
+					Name: "system-ray-state",
+					VolumeSource: corev1.VolumeSource{
+						EmptyDir: &corev1.EmptyDirVolumeSource{},
+					},
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "system-ray-state",
+					MountPath: "/tmp/ray",
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "system-ray-state",
+					MountPath: "/tmp/ray",
+					ReadOnly:  true,
+				},
+			},
+		},
+		{
+			"container target with no sidecar",
+			core.TaskTemplate{
+				Id: &core.Identifier{Name: "ray-id"},
+				Target: &core.TaskTemplate_Container{
+					Container: &core.Container{
+						Image: testImage,
+						Args:  testArgs,
+					},
+				},
+				Custom: rayJobObj,
+			},
+			nil,
+			nil,
+			nil,
+			nil,
+		},
+		{
+			"pod target",
+			core.TaskTemplate{
+				Id: &core.Identifier{Name: "ray-id"},
+				Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "main",
+							Image: "primary-image",
+						},
+					},
+				}),
+				Custom: rayJobObj,
+				Config: map[string]string{
+					flytek8s.PrimaryContainerKey: "main",
+				},
+			},
+			&corev1.Container{
+				Name:  "logs-sidecar",
+				Image: "test-image",
+			},
+			[]corev1.Volume{
+				{
+					Name: "system-ray-state",
+					VolumeSource: corev1.VolumeSource{
+						EmptyDir: &corev1.EmptyDirVolumeSource{},
+					},
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "system-ray-state",
+					MountPath: "/tmp/ray",
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "system-ray-state",
+					MountPath: "/tmp/ray",
+					ReadOnly:  true,
+				},
+			},
+		},
+		{
+			"pod target with existing ray state volume",
+			core.TaskTemplate{
+				Id: &core.Identifier{Name: "ray-id"},
+				Target: transformPodSpecToTaskTemplateTarget(&corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "main",
+							Image: "primary-image",
+							VolumeMounts: []corev1.VolumeMount{
+								{
+									Name:      "test-vol",
+									MountPath: "/tmp/ray",
+								},
+							},
+						},
+					},
+					Volumes: []corev1.Volume{
+						{
+							Name: "test-vol",
+							VolumeSource: corev1.VolumeSource{
+								EmptyDir: &corev1.EmptyDirVolumeSource{},
+							},
+						},
+					},
+				}),
+				Custom: rayJobObj,
+				Config: map[string]string{
+					flytek8s.PrimaryContainerKey: "main",
+				},
+			},
+			&corev1.Container{
+				Name:  "logs-sidecar",
+				Image: "test-image",
+			},
+			[]corev1.Volume{
+				{
+					Name: "test-vol",
+					VolumeSource: corev1.VolumeSource{
+						EmptyDir: &corev1.EmptyDirVolumeSource{},
+					},
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "test-vol",
+					MountPath: "/tmp/ray",
+				},
+			},
+			[]corev1.VolumeMount{
+				{
+					Name:      "test-vol",
+					MountPath: "/tmp/ray",
+					ReadOnly:  true,
+				},
+			},
+		},
+	}
+
+	for _, p := range params {
+		t.Run(p.name, func(t *testing.T) {
+			assert.NoError(t, SetConfig(&Config{
+				LogsSidecar: p.logsSidecarCfg,
+			}))
+			taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount)
+			rayJobResourceHandler := rayJobResourceHandler{}
+			r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
+			assert.NoError(t, err)
+			assert.NotNil(t, r)
+			rayJob, ok := r.(*rayv1.RayJob)
+			assert.True(t, ok)
+
+			headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
+
+			// Check volumes
+			assert.EqualValues(t, p.expectedVolumes, headPodSpec.Volumes)
+
+			// Check containers and respective volume mounts
+			foundPrimaryContainer := false
+			foundLogsSidecar := false
+			for _, cnt := range headPodSpec.Containers {
+				if cnt.Name == "ray-head" {
+					foundPrimaryContainer = true
+					assert.EqualValues(
+						t,
+						p.expectedPrimaryContainerVolumeMounts,
+						cnt.VolumeMounts,
+					)
+				}
+				if p.logsSidecarCfg != nil && cnt.Name == p.logsSidecarCfg.Name {
+					foundLogsSidecar = true
+					assert.EqualValues(
+						t,
+						p.expectedLogsSidecarVolumeMounts,
+						cnt.VolumeMounts,
+					)
+				}
+			}
+			assert.True(t, foundPrimaryContainer)
+			assert.Equal(t, p.logsSidecarCfg != nil, foundLogsSidecar)
+		})
+	}
 }

 func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext {
-	plg := &mocks2.PluginContext{}
-
-	taskExecID := &mocks.TaskExecutionID{}
-	taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{
-		TaskId: &core.Identifier{
-			ResourceType: core.ResourceType_TASK,
-			Name:         "my-task-name",
-			Project:      "my-task-project",
-			Domain:       "my-task-domain",
-			Version:      "1",
-		},
-		NodeExecutionId: &core.NodeExecutionIdentifier{
-			ExecutionId: &core.WorkflowExecutionIdentifier{
-				Name:    "my-execution-name",
Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 1, - }) - taskExecID.OnGetUniqueNodeID().Return("unique-node") - taskExecID.OnGetGeneratedName().Return("generated-name") - - tskCtx := &mocks.TaskExecutionMetadata{} - tskCtx.OnGetTaskExecutionID().Return(taskExecID) - plg.OnTaskExecutionMetadata().Return(tskCtx) - - pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( - func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = pluginState - return 0 - }, - func(v interface{}) error { - return nil - }) - - plg.OnPluginStateReader().Return(&pluginStateReaderMock) - - return plg + plg := &mocks2.PluginContext{} + + taskExecID := &mocks.TaskExecutionID{} + taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + taskExecID.OnGetUniqueNodeID().Return("unique-node") + taskExecID.OnGetGeneratedName().Return("generated-name") + + tskCtx := &mocks.TaskExecutionMetadata{} + tskCtx.OnGetTaskExecutionID().Return(taskExecID) + plg.OnTaskExecutionMetadata().Return(tskCtx) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + plg.OnPluginStateReader().Return(&pluginStateReaderMock) + + return plg } func init() { - f := defaultConfig - f.Logs = logs.LogConfig{ - IsKubernetesEnabled: true, - } - - if err := SetConfig(&f); err != nil { - panic(err) - } + f := defaultConfig + f.Logs = logs.LogConfig{ + IsKubernetesEnabled: true, + } + + if err := SetConfig(&f); err != nil { + panic(err) + } } func TestGetTaskPhase(t *testing.T) { - ctx := context.Background() - rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext(k8s.PluginState{}) - - testCases := []struct { - rayJobPhase rayv1.JobDeploymentStatus - expectedCorePhase pluginsCore.Phase - expectedError bool - }{ - {rayv1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, - {rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, - {rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, - {rayv1.JobDeploymentStatusFailed, pluginsCore.PhasePermanentFailure, false}, - {rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseQueued, false}, - {rayv1.JobDeploymentStatusSuspending, pluginsCore.PhaseQueued, false}, - } - - for _, tc := range testCases { - t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) { - rayObject := &rayv1.RayJob{} - rayObject.Status.JobDeploymentStatus = tc.rayJobPhase - startTime := metav1.NewTime(time.Now()) - rayObject.Status.StartTime = &startTime - phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) - if tc.expectedError { - assert.Error(t, err) - } else { - assert.Nil(t, err) - } - assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) - }) - } + ctx := context.Background() + rayJobResourceHandler := rayJobResourceHandler{} + 
pluginCtx := newPluginContext(k8s.PluginState{}) + + testCases := []struct { + rayJobPhase rayv1.JobDeploymentStatus + expectedCorePhase pluginsCore.Phase + expectedError bool + }{ + {rayv1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, + {rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, + {rayv1.JobDeploymentStatusFailed, pluginsCore.PhasePermanentFailure, false}, + {rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseUndefined, true}, + } + + for _, tc := range testCases { + t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) { + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = tc.rayJobPhase + startTime := metav1.NewTime(time.Now()) + rayObject.Status.StartTime = &startTime + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } + assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) + }) + } } func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} + rayJobResourceHandler := rayJobResourceHandler{} - ctx := context.TODO() + ctx := context.TODO() - pluginState := k8s.PluginState{ - Phase: pluginsCore.PhaseInitializing, - PhaseVersion: pluginsCore.DefaultPhaseVersion, - Reason: "task submitted to K8s", - } - pluginCtx := newPluginContext(pluginState) + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + pluginCtx := newPluginContext(pluginState) - rayObject := &rayv1.RayJob{} - rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing - phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) - assert.NoError(t, err) - assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) + assert.NoError(t, err) + assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) } func TestGetEventInfo_LogTemplates(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - logPlugin tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "namespace", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "namespace", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "namespace", - Uri: "http://test/test-namespace", - }, - }, - }, - { - name: "task execution ID", - rayJob: rayv1.RayJob{}, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "taskExecID", - TemplateURIs: []tasklog.TemplateURI{ - "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", - }, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "taskExecID", - Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", - }, - }, - }, - { - name: "ray cluster name", 
- rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - RayClusterName: "ray-cluster", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray cluster name", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray cluster name", - Uri: "http://test/test-namespace/ray-cluster", - }, - }, - }, - { - name: "ray job ID", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - JobId: "ray-job-1", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray job ID", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray job ID", - Uri: "http://test/test-namespace/ray-job-1", - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ti, err := getEventInfoForRayJob( - logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, - pluginCtx, - &tc.rayJob, - ) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + logPlugin tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "namespace", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "namespace", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "namespace", + Uri: "http://test/test-namespace", + }, + }, + }, + { + name: "task execution ID", + rayJob: rayv1.RayJob{}, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "taskExecID", + TemplateURIs: []tasklog.TemplateURI{ + "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", + }, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "taskExecID", + Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", + }, + }, + }, + { + name: "ray cluster name", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + RayClusterName: "ray-cluster", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray cluster name", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray cluster name", + Uri: "http://test/test-namespace/ray-cluster", + }, + }, + }, + { + name: "ray job ID", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + JobId: "ray-job-1", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray job ID", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray job ID", + Uri: "http://test/test-namespace/ray-job-1", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti, err := getEventInfoForRayJob( + logs.LogConfig{Templates: 
[]tasklog.TemplateLogPlugin{tc.logPlugin}}, + pluginCtx, + &tc.rayJob, + ) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_LogTemplates_V1(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - logPlugin tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "namespace", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "namespace", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "namespace", - Uri: "http://test/test-namespace", - }, - }, - }, - { - name: "task execution ID", - rayJob: rayv1.RayJob{}, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "taskExecID", - TemplateURIs: []tasklog.TemplateURI{ - "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", - }, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "taskExecID", - Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", - }, - }, - }, - { - name: "ray cluster name", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - RayClusterName: "ray-cluster", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray cluster name", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray cluster name", - Uri: "http://test/test-namespace/ray-cluster", - }, - }, - }, - { - name: "ray job ID", - rayJob: rayv1.RayJob{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "test-namespace", - }, - Status: rayv1.RayJobStatus{ - JobId: "ray-job-1", - }, - }, - logPlugin: tasklog.TemplateLogPlugin{ - DisplayName: "ray job ID", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "ray job ID", - Uri: "http://test/test-namespace/ray-job-1", - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ti, err := getEventInfoForRayJob( - logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, - pluginCtx, - &tc.rayJob, - ) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + logPlugin tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "namespace", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "namespace", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "namespace", + Uri: "http://test/test-namespace", + }, + }, + }, + { + name: "task execution ID", + rayJob: rayv1.RayJob{}, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "taskExecID", + TemplateURIs: []tasklog.TemplateURI{ + "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID 
}}/attempt/{{ .taskRetryAttempt }}", + }, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "taskExecID", + Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", + }, + }, + }, + { + name: "ray cluster name", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + RayClusterName: "ray-cluster", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray cluster name", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray cluster name", + Uri: "http://test/test-namespace/ray-cluster", + }, + }, + }, + { + name: "ray job ID", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + JobId: "ray-job-1", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray job ID", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray job ID", + Uri: "http://test/test-namespace/ray-job-1", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti, err := getEventInfoForRayJob( + logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, + pluginCtx, + &tc.rayJob, + ) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_DashboardURL(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - dashboardURLTemplate tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "dashboard URL displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - DashboardURL: "exists", - JobStatus: rayv1.JobStatusRunning, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "Ray Dashboard", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "Ray Dashboard", - Uri: "http://test/generated-name", - }, - }, - }, - { - name: "dashboard URL is not displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - JobStatus: rayv1.JobStatusPending, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "dummy", - TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, - }, - expectedTaskLogs: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: 
"dashboard URL is not displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + JobStatus: rayv1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetEventInfo_DashboardURL_V1(t *testing.T) { - pluginCtx := newPluginContext(k8s.PluginState{}) - testCases := []struct { - name string - rayJob rayv1.RayJob - dashboardURLTemplate tasklog.TemplateLogPlugin - expectedTaskLogs []*core.TaskLog - }{ - { - name: "dashboard URL displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - DashboardURL: "exists", - JobStatus: rayv1.JobStatusRunning, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "Ray Dashboard", - TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, - }, - expectedTaskLogs: []*core.TaskLog{ - { - Name: "Ray Dashboard", - Uri: "http://test/generated-name", - }, - }, - }, - { - name: "dashboard URL is not displayed", - rayJob: rayv1.RayJob{ - Status: rayv1.RayJobStatus{ - JobStatus: rayv1.JobStatusPending, - }, - }, - dashboardURLTemplate: tasklog.TemplateLogPlugin{ - DisplayName: "dummy", - TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, - }, - expectedTaskLogs: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) - assert.NoError(t, err) - assert.Equal(t, tc.expectedTaskLogs, ti.Logs) - }) - } + pluginCtx := newPluginContext(k8s.PluginState{}) + testCases := []struct { + name string + rayJob rayv1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: "dashboard URL is not displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + JobStatus: rayv1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } } func TestGetPropertiesRay(t *testing.T) { - rayJobResourceHandler := rayJobResourceHandler{} - expected := k8s.PluginProperties{} - assert.Equal(t, expected, rayJobResourceHandler.GetProperties()) + rayJobResourceHandler := rayJobResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, 
expected, rayJobResourceHandler.GetProperties()) }