From dfc44b4006f0a2d73f6043526da6506cc9830d76 Mon Sep 17 00:00:00 2001 From: Fabio Graetz Date: Thu, 23 May 2024 21:15:11 +0000 Subject: [PATCH] Fix mpi test Signed-off-by: Fabio Graetz --- .../tasks/plugins/k8s/kfoperators/mpi/mpi_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 900091f78a..52c0ca9a65 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -3,6 +3,7 @@ package mpi import ( "context" "fmt" + "reflect" "testing" "time" @@ -170,6 +171,19 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx }