diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 3c823510ac..3ec8041c6a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -3,6 +3,7 @@ package mpi import ( "context" "fmt" + "reflect" "testing" "time" @@ -117,7 +118,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { } } -func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -170,6 +171,18 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -275,7 +288,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler, mpiObj := dummyMPICustomObj(workers, launcher, slots) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -302,7 +315,7 @@ func TestBuildResourceMPI(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -338,13 +351,13 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) { mpiObj := dummyMPICustomObj(0, 0, 1) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.Error(t, err) mpiObj = dummyMPICustomObj(1, 1, 1) taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) app, ok := resource.(*kubeflowv1.MPIJob) assert.Nil(t, err) assert.Equal(t, true, ok) @@ -458,7 +471,7 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) mpiResourceHandler := mpiOperatorResourceHandler{} r, err := mpiResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -490,7 +503,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, conditionType) } - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResourceCreator(mpiOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -522,6 +535,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + mpiResourceHandler := mpiOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, mpiOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -534,7 +564,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -567,7 +597,7 @@ func TestReplicaCounts(t *testing.T) { mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -653,7 +683,7 @@ func TestBuildResourceMPIV1(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -705,7 +735,7 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -768,7 +798,7 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -783,7 +813,7 @@ func TestGetReplicaCount(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} tfObj := dummyMPICustomObj(1, 1, 0) taskTemplate := dummyMPITaskTemplate("the job", tfObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) MPIJob, ok := resource.(*kubeflowv1.MPIJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 68f8f5b65d..f0b17ff7ff 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -124,7 +124,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate } } -func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { +func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -178,11 +178,10 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1. taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -294,7 +293,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl ptObj := dummyPytorchCustomObj(workers) taskTemplate := dummyPytorchTaskTemplate("job1", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if err != nil { panic(err) } @@ -322,7 +321,7 @@ func TestBuildResourcePytorchElastic(t *testing.T) { ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}) taskTemplate := dummyPytorchTaskTemplate("job2", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -365,7 +364,7 @@ func TestBuildResourcePytorch(t *testing.T) { ptObj := dummyPytorchCustomObj(100) taskTemplate := dummyPytorchTaskTemplate("job3", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -447,7 +446,7 @@ func TestBuildResourcePytorchContainerImage(t *testing.T) { for _, f := range fixtures { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -589,7 +588,7 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "") + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "", k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -622,7 +621,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType) } - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", k8s.PluginState{}) taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -654,6 +653,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", pluginState) + + taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResource(pytorchResourceHandler, 2, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -665,7 +681,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) @@ -685,7 +701,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -716,7 +732,7 @@ func TestReplicaCounts(t *testing.T) { ptObj := dummyPytorchCustomObj(test.workerReplicaCount) taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, res) @@ -798,7 +814,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -842,7 +858,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -902,7 +918,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -973,7 +989,7 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -995,7 +1011,7 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) { taskTemplate.TaskTypeVersion = 1 pytorchResourceHandler := pytorchOperatorResourceHandler{} - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -1031,7 +1047,7 @@ func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.Error(t, err) } @@ -1048,7 +1064,7 @@ func TestGetReplicaCount(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} tfObj := dummyPytorchCustomObj(1) taskTemplate := dummyPytorchTaskTemplate("the job", tfObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index f6c0afb068..3df95a834d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -119,7 +119,7 @@ func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTempl } } -func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -173,11 +173,10 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -290,7 +289,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -315,7 +314,7 @@ func TestGetReplicaCount(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tfObj := dummyTensorFlowCustomObj(1, 0, 0, 0) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) tensorflowJob, ok := resource.(*kubeflowv1.TFJob) @@ -333,7 +332,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { tfObj := dummyTensorFlowCustomObj(100, 50, 1, 1) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -503,7 +502,7 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { taskTemplate := *tCfg.taskTemplate taskTemplate.ExtendedResources = f.extendedResourcesBase tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) r, err := tensorflowResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) assert.NotNil(t, r) @@ -535,7 +534,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, conditionType) } - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -567,6 +566,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -580,7 +596,7 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) @@ -627,7 +643,7 @@ func TestReplicaCounts(t *testing.T) { tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount, test.evaluatorReplicaCount) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -767,7 +783,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -833,7 +849,7 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -907,7 +923,7 @@ func TestBuildResourceTensorFlowV1ResourceTolerations(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource)