diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index ce8e16f1fc..9451115c1e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -148,7 +148,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem } } -func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool) pluginsCore.TaskExecutionContext { +func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -199,11 +199,10 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata.OnGetOverrides().Return(overrides) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -218,7 +217,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -329,7 +328,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate(customImage, nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -362,7 +361,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -419,7 +418,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -474,7 +473,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -508,7 +507,7 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { flytek8s.DefaultPodTemplateStore.Store(podTemplate) daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -628,7 +627,7 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyDaskTaskTemplate("", nil, "") taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false) + taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false, k8s.PluginState{}) daskResourceHandler := daskResourceHandler{} r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -694,7 +693,7 @@ func TestBuildIdentityResourceDask(t *testing.T) { } taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata()) if err != nil { panic(err) @@ -707,7 +706,7 @@ func TestGetTaskPhaseDask(t *testing.T) { ctx := context.TODO() taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob("")) assert.NoError(t, err) @@ -751,3 +750,21 @@ func TestGetTaskPhaseDask(t *testing.T) { assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) } + +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + daskResourceHandler := daskResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, pluginState) + + taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 42ab0bf5ac..712b2f634e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -681,7 +681,7 @@ func TestInjectLogsSidecar(t *testing.T) { } } -func newPluginContext() k8s.PluginContext { +func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext { plg := &mocks2.PluginContext{} taskExecID := &mocks.TaskExecutionID{} @@ -709,11 +709,10 @@ func newPluginContext() k8s.PluginContext { tskCtx.OnGetTaskExecutionID().Return(taskExecID) plg.OnTaskExecutionMetadata().Return(tskCtx) - inputState := k8s.PluginState{} pluginStateReaderMock := mocks.PluginStateReader{} - pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( func(v interface{}) uint8 { - *(v.(*k8s.PluginState)) = inputState + *(v.(*k8s.PluginState)) = pluginState return 0 }, func(v interface{}) error { @@ -739,7 +738,7 @@ func init() { func TestGetTaskPhase(t *testing.T) { ctx := context.Background() rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { rayJobPhase rayv1alpha1.JobStatus @@ -783,7 +782,7 @@ func TestGetTaskPhase(t *testing.T) { func TestGetTaskPhase_V1(t *testing.T) { ctx := context.Background() rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { rayJobPhase rayv1.JobStatus @@ -824,8 +823,28 @@ func TestGetTaskPhase_V1(t *testing.T) { } } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + rayJobResourceHandler := rayJobResourceHandler{} + + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + pluginCtx := newPluginContext(pluginState) + + rayObject := &rayv1alpha1.RayJob{} + rayObject.Status.JobDeploymentStatus = rayv1alpha1.JobDeploymentStatusInitializing + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + + assert.NoError(t, err) + assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetEventInfo_LogTemplates(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1alpha1.RayJob @@ -924,7 +943,7 @@ func TestGetEventInfo_LogTemplates(t *testing.T) { } func TestGetEventInfo_LogTemplates_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -1023,7 +1042,7 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) { } func TestGetEventInfo_DashboardURL(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1alpha1.RayJob @@ -1075,7 +1094,7 @@ func TestGetEventInfo_DashboardURL(t *testing.T) { } func TestGetEventInfo_DashboardURL_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob