diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index ce8e16f1fce..9451115c1e7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -148,7 +148,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem } } -func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool) pluginsCore.TaskExecutionContext { +func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -199,11 +199,10 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata.OnGetOverrides().Return(overrides) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -218,7 +217,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -329,7 +328,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate(customImage, nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -362,7 +361,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -419,7 +418,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -474,7 +473,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -508,7 +507,7 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { flytek8s.DefaultPodTemplateStore.Store(podTemplate) daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -628,7 +627,7 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyDaskTaskTemplate("", nil, "") taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false) + taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false, k8s.PluginState{}) daskResourceHandler := daskResourceHandler{} r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -694,7 +693,7 @@ func TestBuildIdentityResourceDask(t *testing.T) { } taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata()) if err != nil { panic(err) @@ -707,7 +706,7 @@ func TestGetTaskPhaseDask(t *testing.T) { ctx := context.TODO() taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob("")) assert.NoError(t, err) @@ -751,3 +750,21 @@ func TestGetTaskPhaseDask(t *testing.T) { assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) } + +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + daskResourceHandler := daskResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, pluginState) + + taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 52c0ca9a65a..9b1ff39075c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -118,7 +118,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { } } -func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -172,11 +172,10 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -289,7 +288,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler, mpiObj := dummyMPICustomObj(workers, launcher, slots) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -316,7 +315,7 @@ func TestBuildResourceMPI(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -352,13 +351,13 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) { mpiObj := dummyMPICustomObj(0, 0, 1) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.Error(t, err) mpiObj = dummyMPICustomObj(1, 1, 1) taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) app, ok := resource.(*kubeflowv1.MPIJob) assert.Nil(t, err) assert.Equal(t, true, ok) @@ -472,7 +471,7 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) mpiResourceHandler := mpiOperatorResourceHandler{} r, err := mpiResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -504,7 +503,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, conditionType) } - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResourceCreator(mpiOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -536,6 +535,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + mpiResourceHandler := mpiOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, mpiOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -548,7 +564,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -581,7 +597,7 @@ func TestReplicaCounts(t *testing.T) { mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -705,7 +721,7 @@ func TestBuildResourceMPIV1(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -780,7 +796,7 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -880,7 +896,7 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -896,7 +912,7 @@ func TestGetReplicaCount(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} tfObj := dummyMPICustomObj(1, 1, 0) taskTemplate := dummyMPITaskTemplate("the job", tfObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) MPIJob, ok := resource.(*kubeflowv1.MPIJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 1e9c038d120..b6e27aea691 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -124,7 +124,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate } } -func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { +func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -178,11 +178,10 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1. taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -294,7 +293,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl ptObj := dummyPytorchCustomObj(workers) taskTemplate := dummyPytorchTaskTemplate("job1", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if err != nil { panic(err) } @@ -322,7 +321,7 @@ func TestBuildResourcePytorchElastic(t *testing.T) { ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}) taskTemplate := dummyPytorchTaskTemplate("job2", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -365,7 +364,7 @@ func TestBuildResourcePytorch(t *testing.T) { ptObj := dummyPytorchCustomObj(100) taskTemplate := dummyPytorchTaskTemplate("job3", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -447,7 +446,7 @@ func TestBuildResourcePytorchContainerImage(t *testing.T) { for _, f := range fixtures { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -589,7 +588,7 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "") + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "", k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -622,7 +621,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType) } - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", k8s.PluginState{}) taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -654,6 +653,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", pluginState) + + taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResource(pytorchResourceHandler, 4, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -665,7 +681,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) @@ -685,7 +701,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -716,7 +732,7 @@ func TestReplicaCounts(t *testing.T) { ptObj := dummyPytorchCustomObj(test.workerReplicaCount) taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, res) @@ -834,7 +850,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -896,7 +912,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -978,7 +994,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -1086,7 +1102,7 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -1121,7 +1137,7 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) { taskTemplate.TaskTypeVersion = 1 pytorchResourceHandler := pytorchOperatorResourceHandler{} - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -1170,7 +1186,7 @@ func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.Error(t, err) } } @@ -1188,7 +1204,7 @@ func TestGetReplicaCount(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} tfObj := dummyPytorchCustomObj(1) taskTemplate := dummyPytorchTaskTemplate("the job", tfObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index f755b12ddf7..bea37d10063 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -119,7 +119,7 @@ func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTempl } } -func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -173,11 +173,10 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -290,7 +289,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -315,7 +314,7 @@ func TestGetReplicaCount(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tfObj := dummyTensorFlowCustomObj(1, 0, 0, 0) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) tensorflowJob, ok := resource.(*kubeflowv1.TFJob) @@ -333,7 +332,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { tfObj := dummyTensorFlowCustomObj(100, 50, 1, 1) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -528,7 +527,7 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { taskTemplate := *tCfg.taskTemplate taskTemplate.ExtendedResources = f.extendedResourcesBase tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) r, err := tensorflowResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) assert.NotNil(t, r) @@ -561,7 +560,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, conditionType) } - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -593,6 +592,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -606,7 +622,7 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) @@ -653,7 +669,7 @@ func TestReplicaCounts(t *testing.T) { tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount, test.evaluatorReplicaCount) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -868,7 +884,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -957,7 +973,7 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -1070,7 +1086,7 @@ func TestBuildResourceTensorFlowV1ResourceTolerations(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 7d4de88a716..77152023083 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -676,7 +676,7 @@ func TestInjectLogsSidecar(t *testing.T) { } } -func newPluginContext() k8s.PluginContext { +func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext { plg := &mocks2.PluginContext{} taskExecID := &mocks.TaskExecutionID{} @@ -704,11 +704,10 @@ func newPluginContext() k8s.PluginContext { tskCtx.OnGetTaskExecutionID().Return(taskExecID) plg.OnTaskExecutionMetadata().Return(tskCtx) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -734,7 +733,7 @@ func init() { func TestGetTaskPhase(t *testing.T) { ctx := context.Background() rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { rayJobPhase rayv1.JobDeploymentStatus @@ -765,8 +764,28 @@ func TestGetTaskPhase(t *testing.T) { } } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + rayJobResourceHandler := rayJobResourceHandler{} + + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + pluginCtx := newPluginContext(pluginState) + + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + + assert.NoError(t, err) + assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetEventInfo_LogTemplates(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -865,7 +884,7 @@ func TestGetEventInfo_LogTemplates(t *testing.T) { } func TestGetEventInfo_LogTemplates_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -964,7 +983,7 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) { } func TestGetEventInfo_DashboardURL(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -1016,7 +1035,7 @@ func TestGetEventInfo_DashboardURL(t *testing.T) { } func TestGetEventInfo_DashboardURL_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 7feeb80e06f..c7959c60bb9 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -97,7 +97,7 @@ func TestGetEventInfo(t *testing.T) { }, }, })) - taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false, k8s.PluginState{}) info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState)) assert.NoError(t, err) assert.Len(t, info.Logs, 6) @@ -172,7 +172,7 @@ func TestGetTaskPhase(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} ctx := context.TODO() - taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false, k8s.PluginState{}) taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued) @@ -234,6 +234,24 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false, pluginState) + + taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.SubmittedState)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication { return &sj.SparkApplication{ @@ -353,7 +371,7 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec * } } -func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) pluginsCore.TaskExecutionContext { +func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -413,11 +431,10 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val") taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -576,7 +593,7 @@ func TestBuildResourceContainer(t *testing.T) { defaultConfig := defaultPluginConfig() assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) - resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true)) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) @@ -724,7 +741,7 @@ func TestBuildResourceContainer(t *testing.T) { dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4" taskTemplate = dummySparkTaskTemplateContainer("blah-1", dummyConfWithRequest) - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok = resource.(*sj.SparkApplication) @@ -734,7 +751,7 @@ func TestBuildResourceContainer(t *testing.T) { assert.Equal(t, dummyConfWithRequest["spark.kubernetes.executor.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"]) // Case 3: Interruptible False - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok = resource.(*sj.SparkApplication) @@ -782,7 +799,7 @@ func TestBuildResourceContainer(t *testing.T) { // Case 4: Invalid Spark Task-Template taskTemplate.Custom = nil - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.NotNil(t, err) assert.Nil(t, resource) } @@ -802,7 +819,7 @@ func TestBuildResourcePodTemplate(t *testing.T) { taskTemplate.GetK8SPod() sparkResourceHandler := sparkResourceHandler{} - taskCtx := dummySparkTaskContext(taskTemplate, true) + taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err)