From 91d6d401babe332cb1b3ee7f971641efb064586f Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Thu, 1 Aug 2024 23:46:52 -0700 Subject: [PATCH] Fix nil pointer when task plugin load returns error (#5622) Signed-off-by: Bugra Gedik --- .../pkg/controller/nodes/task/handler.go | 7 ++++--- .../pkg/controller/nodes/task/handler_test.go | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index d1595890d8..9ec47985c9 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -248,13 +248,14 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error logger.Infof(ctx, "Loading Plugin [%s] ENABLED", p.ID) cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, p) + if err != nil { + return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID) + } + if cp.GetID() == agent.ID { t.agentService.CorePlugin = cp } - if err != nil { - return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID) - } // For every default plugin for a task type specified in flytepropeller config we validate that the plugin's // static definition includes that task type as something it is registered to handle. for _, tt := range p.RegisteredTaskTypes { diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 4e6798cfef..31e1be9a7f 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -126,6 +126,8 @@ func Test_task_Setup(t *testing.T) { k8sPluginDefault := &pluginK8sMocks.Plugin{} k8sPluginDefault.OnGetProperties().Return(pluginK8s.PluginProperties{}) + loadErrorPluginType := "loadError" + corePluginEntry := pluginCore.PluginEntry{ ID: corePluginType, RegisteredTaskTypes: []pluginCore.TaskType{corePluginType}, @@ -154,6 +156,13 @@ func Test_task_Setup(t *testing.T) { RegisteredTaskTypes: []pluginCore.TaskType{k8sPluginDefaultType}, ResourceToWatch: &v1.Pod{}, } + loadErrorPluginEntry := pluginCore.PluginEntry{ + ID: loadErrorPluginType, + RegisteredTaskTypes: []pluginCore.TaskType{loadErrorPluginType}, + LoadPlugin: func(ctx context.Context, iCtx pluginCore.SetupContext) (pluginCore.Plugin, error) { + return nil, fmt.Errorf("test") + }, + } type wantFields struct { pluginIDs map[pluginCore.TaskType]string @@ -232,6 +241,15 @@ func Test_task_Setup(t *testing.T) { }, }, false}, + {"load-error", + testPluginRegistry{ + core: []pluginCore.PluginEntry{loadErrorPluginEntry}, + k8s: []pluginK8s.PluginEntry{}, + }, + []string{loadErrorPluginType}, + map[string]string{corePluginType: loadErrorPluginType}, + wantFields{}, + true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {