diff --git a/pkg/cache/context.go b/pkg/cache/context.go index 0f7764ade..474a1deb9 100644 --- a/pkg/cache/context.go +++ b/pkg/cache/context.go @@ -202,7 +202,7 @@ func (ctx *Context) updateNodeInternal(node *v1.Node, register bool) { if applicationID == "" { ctx.updateForeignPod(pod) } else { - ctx.updateYuniKornPod(pod) + ctx.updateYuniKornPod(applicationID, pod) } } @@ -296,22 +296,26 @@ func (ctx *Context) UpdatePod(_, newObj interface{}) { log.Log(log.ShimContext).Error("failed to update pod", zap.Error(err)) return } - if utils.GetApplicationIDFromPod(pod) == "" { + applicationID := utils.GetApplicationIDFromPod(pod) + if applicationID == "" { ctx.updateForeignPod(pod) } else { - ctx.updateYuniKornPod(pod) + ctx.updateYuniKornPod(applicationID, pod) } } -func (ctx *Context) updateYuniKornPod(pod *v1.Pod) { - // treat terminated pods like a remove - if utils.IsPodTerminated(pod) { - if taskMeta, ok := getTaskMetadata(pod); ok { - if app := ctx.getApplication(taskMeta.ApplicationID); app != nil { - ctx.notifyTaskComplete(taskMeta.ApplicationID, taskMeta.TaskID) - } +func (ctx *Context) updateYuniKornPod(appID string, pod *v1.Pod) { + var app *Application + taskID := string(pod.UID) + if app = ctx.getApplication(appID); app != nil { + if task, err := app.GetTask(taskID); task != nil && err == nil { + task.setTaskPod(pod) } + } + // treat terminated pods like a remove + if utils.IsPodTerminated(pod) { + ctx.notifyTaskComplete(appID, taskID) log.Log(log.ShimContext).Debug("Request to update terminated pod, removing from cache", zap.String("podName", pod.Name)) ctx.schedulerCache.RemovePod(pod) return diff --git a/pkg/cache/context_test.go b/pkg/cache/context_test.go index 002b3c813..028bd8173 100644 --- a/pkg/cache/context_test.go +++ b/pkg/cache/context_test.go @@ -526,6 +526,11 @@ func TestUpdatePod(t *testing.T) { context.UpdatePod(pod1, pod3) pod = context.schedulerCache.GetPod(uid1) assert.Check(t, pod == nil, "pod still found after termination") + app := context.getApplication("yunikorn-test-00001") + // ensure that an updated pod is updated inside the Task + task, err := app.GetTask("UID-00001") + assert.NilError(t, err) + assert.Assert(t, task.GetTaskPod() == pod3, "task pod has not been updated") // ensure a non-terminated pod is updated context.UpdatePod(pod1, pod2) diff --git a/pkg/cache/task.go b/pkg/cache/task.go index 97f041ea2..02b07d16e 100644 --- a/pkg/cache/task.go +++ b/pkg/cache/task.go @@ -176,7 +176,7 @@ func (task *Task) getNodeName() string { } func (task *Task) DeleteTaskPod() error { - return task.context.apiProvider.GetAPIs().KubeClient.Delete(task.pod) + return task.context.apiProvider.GetAPIs().KubeClient.Delete(task.GetTaskPod()) } func (task *Task) UpdateTaskPodStatus(pod *v1.Pod) (*v1.Pod, error) { @@ -544,9 +544,11 @@ func (task *Task) releaseAllocation() { // this reduces the scheduling overhead by blocking such // request away from the core scheduler. func (task *Task) sanityCheckBeforeScheduling() error { + task.lock.RLock() // Check PVCs used by the pod namespace := task.pod.Namespace manifest := &(task.pod.Spec) + task.lock.RUnlock() for i := range manifest.Volumes { volume := &manifest.Volumes[i] if volume.PersistentVolumeClaim == nil { @@ -599,3 +601,9 @@ func (task *Task) failWithEvent(errorMessage, actionReason string) { events.GetRecorder().Eventf(task.pod.DeepCopy(), nil, v1.EventTypeWarning, actionReason, actionReason, errorMessage) } + +func (task *Task) setTaskPod(pod *v1.Pod) { + task.lock.Lock() + defer task.lock.Unlock() + task.pod = pod +}