Skip to content

Commit

Permalink
feat(replica-spec): Update corresponding golang files based on protob…
Browse files Browse the repository at this point in the history
…uf changes

Resolves: flyteorg#4408
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed May 11, 2024
1 parent b645e7b commit 61a95d6
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
kfplugins "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
Expand Down Expand Up @@ -241,11 +242,11 @@ func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.
}

// Get k8s restart policy from flyte kubeflow plugins restart policy.
func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy {
restartPolicyMap := map[kfplugins.RestartPolicy]commonOp.RestartPolicy{
kfplugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever,
kfplugins.RestartPolicy_RESTART_POLICY_ON_FAILURE: commonOp.RestartPolicyOnFailure,
kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS: commonOp.RestartPolicyAlways,
func ParseRestartPolicy(flyteRestartPolicy plugins.RestartPolicy) commonOp.RestartPolicy {
restartPolicyMap := map[plugins.RestartPolicy]commonOp.RestartPolicy{
plugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever,
plugins.RestartPolicy_RESTART_POLICY_ON_FAILURE: commonOp.RestartPolicyOnFailure,
plugins.RestartPolicy_RESTART_POLICY_ALWAYS: commonOp.RestartPolicyAlways,
}
return restartPolicyMap[flyteRestartPolicy]
}
Expand Down Expand Up @@ -290,10 +291,7 @@ func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext
}

type kfDistributedReplicaSpec interface {
GetReplicas() int32
GetImage() string
GetResources() *core.Resources
GetRestartPolicy() kfplugins.RestartPolicy
GetCommon() *plugins.CommonReplicaSpec
}

type allowsCommandOverride interface {
Expand All @@ -302,8 +300,8 @@ type allowsCommandOverride interface {

func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) {
taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{}
if rs != nil && rs.GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources())
if rs != nil && rs.GetCommon().GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetCommon().GetResources())
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
Expand All @@ -329,16 +327,16 @@ func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExe
if err := OverrideContainerSpec(
&replicaSpec.Template.Spec,
primaryContainerName,
rs.GetImage(),
rs.GetCommon().GetImage(),
command,
); err != nil {
return nil, err
}

replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy())
replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetCommon().GetRestartPolicy())

if !isMaster {
replicas := rs.GetReplicas()
replicas := rs.GetCommon().GetReplicas()
replicaSpec.Replicas = &replicas
}
}
Expand Down
102 changes: 56 additions & 46 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,29 +596,33 @@ func TestBuildResourceMPIV1(t *testing.T) {
workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"}
taskConfig := &kfplugins.DistributedMPITrainingTask{
LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
Image: testImage,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
Common: &kfplugins.CommonReplicaSpec{
Image: testImage,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
},
},
},
Command: launcherCommand,
},
WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
},
},
},
Command: workerCommand,
Expand Down Expand Up @@ -673,15 +677,17 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) {

taskConfig := &kfplugins.DistributedMPITrainingTask{
WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
},
},
},
Command: []string{"/usr/sbin/sshd", "/.sshd_config"},
Expand Down Expand Up @@ -735,29 +741,33 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) {

taskConfig := &kfplugins.DistributedMPITrainingTask{
LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
Common: &kfplugins.CommonReplicaSpec{
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
},
},
},
},
WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
{Name: core.Resources_GPU, Value: "1"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
},
},
},
Expand Down
116 changes: 66 additions & 50 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -729,29 +729,33 @@ func TestReplicaCounts(t *testing.T) {
func TestBuildResourcePytorchV1(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Image: testImageMaster,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
Common: &kfplugins.CommonReplicaSpec{
Image: testImageMaster,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
},
},
RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS,
},
RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS,
},
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
},
},
},
},
Expand Down Expand Up @@ -814,7 +818,9 @@ func TestBuildResourcePytorchV1(t *testing.T) {
func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 100,
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
},
},
RunPolicy: &kfplugins.RunPolicy{
CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
Expand Down Expand Up @@ -845,15 +851,17 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
},
},
},
},
Expand Down Expand Up @@ -926,29 +934,33 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) {

taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
Common: &kfplugins.CommonReplicaSpec{
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
{Name: core.Resources_MEMORY, Value: "250Mi"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
{Name: core.Resources_MEMORY, Value: "500Mi"},
},
},
},
},
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
{Name: core.Resources_GPU, Value: "1"},
Common: &kfplugins.CommonReplicaSpec{
Replicas: 100,
Resources: &core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "1024m"},
{Name: core.Resources_MEMORY, Value: "1Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "2048m"},
{Name: core.Resources_MEMORY, Value: "2Gi"},
{Name: core.Resources_GPU, Value: "1"},
},
},
},
},
Expand All @@ -973,7 +985,9 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) {
func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 2,
Common: &kfplugins.CommonReplicaSpec{
Replicas: 2,
},
},
ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"},
}
Expand Down Expand Up @@ -1011,7 +1025,9 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 0,
Common: &kfplugins.CommonReplicaSpec{
Replicas: 0,
},
},
}
pytorchResourceHandler := pytorchOperatorResourceHandler{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
for t, cfg := range replicaSpecCfgMap {
// Short circuit if replica set has no replicas to avoid unnecessarily
// generating pod specs
if cfg.GetReplicas() <= 0 {
if cfg.GetCommon().GetReplicas() <= 0 {
continue
}
rs, err := common.ToReplicaSpecWithOverrides(ctx, taskCtx, cfg, kubeflowv1.TFJobDefaultContainerName, false)
Expand Down
Loading

0 comments on commit 61a95d6

Please sign in to comment.