diff --git a/pkg/cache/application.go b/pkg/cache/application.go index 3bedb132e..a6ec1f3f8 100644 --- a/pkg/cache/application.go +++ b/pkg/cache/application.go @@ -115,14 +115,10 @@ func (app *Application) canHandle(ev events.ApplicationEvent) bool { return app.sm.Can(ev.GetEvent()) } -func (app *Application) GetTask(taskID string) (*Task, error) { +func (app *Application) GetTask(taskID string) *Task { app.lock.RLock() defer app.lock.RUnlock() - if task, ok := app.taskMap[taskID]; ok { - return task, nil - } - return nil, fmt.Errorf("task %s doesn't exist in application %s", - taskID, app.applicationID) + return app.taskMap[taskID] } func (app *Application) GetApplicationID() string { diff --git a/pkg/cache/application_test.go b/pkg/cache/application_test.go index 5d9a3ce9f..da6f085dc 100644 --- a/pkg/cache/application_test.go +++ b/pkg/cache/application_test.go @@ -1184,9 +1184,7 @@ func TestPlaceholderTimeoutEvents(t *testing.T) { }) assert.Assert(t, task1 != nil) assert.Equal(t, task1.GetTaskID(), "task02") - - _, taskErr := app.GetTask("task02") - assert.NilError(t, taskErr, "Task should exist") + assert.Assert(t, app.GetTask("task02") != nil, "Task should exist") task1.allocationKey = allocationKey diff --git a/pkg/cache/context.go b/pkg/cache/context.go index f9abaed63..0f7764ade 100644 --- a/pkg/cache/context.go +++ b/pkg/cache/context.go @@ -351,7 +351,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod) { } // add task if it doesn't already exist - if _, taskErr := app.GetTask(string(pod.UID)); taskErr != nil { + if task := app.GetTask(string(pod.UID)); task == nil { ctx.addTask(&AddTaskRequest{ Metadata: taskMeta, }) @@ -1097,8 +1097,8 @@ func (ctx *Context) addTask(request *AddTaskRequest) *Task { zap.String("appID", request.Metadata.ApplicationID), zap.String("taskID", request.Metadata.TaskID)) if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil { - existingTask, err := app.GetTask(request.Metadata.TaskID) - if err != nil { + existingTask := app.GetTask(request.Metadata.TaskID) + if existingTask == nil { var originator bool // Is this task the originator of the application? @@ -1156,8 +1156,8 @@ func (ctx *Context) getTask(appID string, taskID string) *Task { zap.String("appID", appID)) return nil } - task, err := app.GetTask(taskID) - if err != nil { + task := app.GetTask(taskID) + if task == nil { log.Log(log.ShimContext).Debug("task is not found in applications", zap.String("taskID", taskID), zap.String("appID", appID)) diff --git a/pkg/cache/context_test.go b/pkg/cache/context_test.go index 23f68f2dc..002b3c813 100644 --- a/pkg/cache/context_test.go +++ b/pkg/cache/context_test.go @@ -1007,8 +1007,7 @@ func TestRecoverTask(t *testing.T) { for _, tt := range taskInfoVerifiers { t.Run(tt.taskID, func(t *testing.T) { // verify the info for the recovered task - rt, err := app.GetTask(tt.taskID) - assert.NilError(t, err) + rt := app.GetTask(tt.taskID) assert.Equal(t, rt.GetTaskState(), tt.expectedState) assert.Equal(t, rt.allocationKey, tt.expectedAllocationKey) assert.Equal(t, rt.pod.Name, tt.expectedPodName) @@ -2142,9 +2141,8 @@ func TestTaskRemoveOnCompletion(t *testing.T) { // check removal app.Schedule() - appTask, err := app.GetTask(taskUID1) + appTask := app.GetTask(taskUID1) assert.Assert(t, appTask == nil) - assert.Error(t, err, "task task00001 doesn't exist in application app01") } func TestAssumePod(t *testing.T) { diff --git a/pkg/plugin/scheduler_plugin.go b/pkg/plugin/scheduler_plugin.go index 6d0351ca2..7b46b619b 100644 --- a/pkg/plugin/scheduler_plugin.go +++ b/pkg/plugin/scheduler_plugin.go @@ -302,7 +302,7 @@ func NewSchedulerPlugin(_ context.Context, _ runtime.Object, handle framework.Ha func (sp *YuniKornSchedulerPlugin) getTask(appID, taskID string) (app *cache.Application, task *cache.Task, ok bool) { if app := sp.context.GetApplication(appID); app != nil { - if task, err := app.GetTask(taskID); err == nil { + if task := app.GetTask(taskID); task != nil { return app, task, true } } diff --git a/pkg/shim/scheduler_mock_test.go b/pkg/shim/scheduler_mock_test.go index 1bbe5f02c..1e8f19106 100644 --- a/pkg/shim/scheduler_mock_test.go +++ b/pkg/shim/scheduler_mock_test.go @@ -167,8 +167,7 @@ func (fc *MockScheduler) waitAndAssertTaskState(t *testing.T, appID, taskID, exp assert.Equal(t, app != nil, true) assert.Equal(t, app.GetApplicationID(), appID) - task, err := app.GetTask(taskID) - assert.NilError(t, err, "Task retrieval failed") + task := app.GetTask(taskID) deadline := time.Now().Add(10 * time.Second) for { if task.GetTaskState() == expectedState {