diff --git a/pkg/cache/application.go b/pkg/cache/application.go
index 8438957ba..de34ed9c8 100644
--- a/pkg/cache/application.go
+++ b/pkg/cache/application.go
@@ -48,7 +48,7 @@ type Application struct {
 	groups                      []string
 	taskMap                     map[string]*Task
 	tags                        map[string]string
-	taskGroups                  []TaskGroup
+	taskGroups                  map[string]TaskGroup
 	taskGroupsDefinition        string
 	schedulingParamsDefinition  string
 	placeholderOwnerReferences  []metav1.OwnerReference
@@ -80,7 +80,7 @@ func NewApplication(appID, queueName, user string, groups []string, tags map[str
 		taskMap:                 taskMap,
 		tags:                    tags,
 		sm:                      newAppState(),
-		taskGroups:              make([]TaskGroup, 0),
+		taskGroups:              make(map[string]TaskGroup),
 		lock:                    &locking.RWMutex{},
 		schedulerAPI:            scheduler,
 		placeholderTimeoutInSec: 0,
@@ -163,12 +163,33 @@ func (app *Application) GetSchedulingParamsDefinition() string {
 	return app.schedulingParamsDefinition
 }
 
+// checkTaskGroups validates the given task groups and reports each duplicated name
+func (app *Application) checkTaskGroups(taskGroups []TaskGroup, pod *v1.Pod) {
+	tgs := make(map[string]TaskGroup)
+
+	for _, taskGroup := range taskGroups {
+		if _, exists := tgs[taskGroup.Name]; exists {
+			// duplicated task group: warn in the log and surface an event to the user
+			log.Log(log.ShimCacheApplication).Warn("duplicate task-group name within the task-groups",
+				zap.String("appID", app.applicationID),
+				zap.String("groupName", taskGroup.Name))
+			events.GetRecorder().Eventf(pod.DeepCopy(), nil, v1.EventTypeWarning, "GangScheduling",
+				"TaskGroupDuplicated", "Application %s has duplicated task-group %s", app.applicationID, taskGroup.Name)
+		} else {
+			tgs[taskGroup.Name] = taskGroup
+		}
+	}
+}
+
 func (app *Application) setTaskGroups(taskGroups []TaskGroup) {
 	app.lock.Lock()
 	defer app.lock.Unlock()
-	app.taskGroups = taskGroups
-	for _, taskGroup := range app.taskGroups {
-		app.placeholderAsk = common.Add(app.placeholderAsk, common.GetTGResource(taskGroup.MinResource, int64(taskGroup.MinMember)))
+	for _, taskGroup := range taskGroups {
+		// skip duplicated task groups: only the first occurrence is added to the app
+		if _, exists := app.taskGroups[taskGroup.Name]; !exists {
+			app.taskGroups[taskGroup.Name] = taskGroup
+			app.placeholderAsk = common.Add(app.placeholderAsk, common.GetTGResource(taskGroup.MinResource, int64(taskGroup.MinMember)))
+		}
 	}
 }
 
@@ -181,7 +202,15 @@ func (app *Application) getPlaceholderAsk() *si.Resource {
 func (app *Application) getTaskGroups() []TaskGroup {
 	app.lock.RLock()
 	defer app.lock.RUnlock()
-	return app.taskGroups
+
+	if len(app.taskGroups) > 0 {
+		taskGroups := make([]TaskGroup, 0, len(app.taskGroups))
+		for _, taskGroup := range app.taskGroups {
+			taskGroups = append(taskGroups, taskGroup)
+		}
+		return taskGroups
+	}
+	return nil
 }
 
 func (app *Application) setPlaceholderOwnerReferences(ref []metav1.OwnerReference) {
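Two behavioral notes on the map-backed storage above: setTaskGroups is now first-occurrence-wins (a duplicated name neither overwrites the stored group nor adds to placeholderAsk again), and getTaskGroups no longer preserves input order, since Go map iteration order is unspecified; callers indexing tgs[0] stay deterministic only while the fixture holds a single group. A minimal standalone sketch of the dedup semantics, with TaskGroup trimmed down to the two fields that matter here (not the shim's full type):

package main

import "fmt"

type TaskGroup struct {
	Name      string
	MinMember int32
}

func main() {
	input := []TaskGroup{
		{Name: "tg-1", MinMember: 3},
		{Name: "tg-1", MinMember: 5}, // duplicated name: skipped, first wins
		{Name: "tg-2", MinMember: 2},
	}
	stored := make(map[string]TaskGroup)
	for _, tg := range input {
		if _, exists := stored[tg.Name]; !exists {
			stored[tg.Name] = tg
		}
	}
	fmt.Println(len(stored))              // 2
	fmt.Println(stored["tg-1"].MinMember) // 3: the first definition was kept
}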
diff --git a/pkg/cache/context.go b/pkg/cache/context.go
index 613bbc520..cb9d2e4e8 100644
--- a/pkg/cache/context.go
+++ b/pkg/cache/context.go
@@ -338,7 +338,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod, app *Application) {
 		}
 		app = ctx.addApplication(&AddApplicationRequest{
 			Metadata: appMeta,
-		})
+		}, pod)
 	}
 
 	// get task metadata
@@ -957,10 +957,10 @@ func (ctx *Context) AddApplication(request *AddApplicationRequest) *Application
 	ctx.lock.Lock()
 	defer ctx.lock.Unlock()
 
-	return ctx.addApplication(request)
+	return ctx.addApplication(request, nil)
 }
 
-func (ctx *Context) addApplication(request *AddApplicationRequest) *Application {
+func (ctx *Context) addApplication(request *AddApplicationRequest, pod *v1.Pod) *Application {
 	log.Log(log.ShimContext).Debug("AddApplication", zap.Any("Request", request))
 	if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil {
 		return app
@@ -980,6 +980,7 @@ func (ctx *Context) addApplication(request *AddApplicationRequest) *Application
 		request.Metadata.Groups,
 		request.Metadata.Tags,
 		ctx.apiProvider.GetAPIs().SchedulerAPI)
+	app.checkTaskGroups(request.Metadata.TaskGroups, pod)
 	app.setTaskGroups(request.Metadata.TaskGroups)
 	app.setTaskGroupsDefinition(request.Metadata.Tags[constants.AnnotationTaskGroups])
 	app.setSchedulingParamsDefinition(request.Metadata.Tags[constants.AnnotationSchedulingPolicyParam])
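Threading the originating pod through addApplication is what lets checkTaskGroups attach the warning event to the pod that created the app; the exported AddApplication path has no pod and passes nil. The nil path does not panic because the generated DeepCopy is nil-safe, as the sketch below demonstrates, though whether the event recorder does anything useful with a nil regarding object is recorder-dependent, so a pod != nil guard around the Eventf call would be the conservative hardening (a suggestion, not something the patch does). Standalone sketch, assuming the k8s.io/api module is on the module path:

package main

import (
	"fmt"

	v1 "k8s.io/api/core/v1"
)

func main() {
	var pod *v1.Pod // nil, as on the exported AddApplication path
	fmt.Println(pod.DeepCopy() == nil) // true: the generated DeepCopy starts with a nil-receiver check
}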
diff --git a/pkg/cache/placeholder_test.go b/pkg/cache/placeholder_test.go
index cbabd2891..7203324dd 100644
--- a/pkg/cache/placeholder_test.go
+++ b/pkg/cache/placeholder_test.go
@@ -119,9 +119,10 @@ func TestNewPlaceholder(t *testing.T) {
 	assert.Equal(t, app.placeholderAsk.Resources[siCommon.Memory].Value, int64(10*1024*1000*1000))
 	assert.Equal(t, app.placeholderAsk.Resources["pods"].Value, int64(10))
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, holder.appID, appID)
-	assert.Equal(t, holder.taskGroupName, app.taskGroups[0].Name)
+	assert.Equal(t, holder.taskGroupName, tgs[0].Name)
 	assert.Equal(t, holder.pod.Spec.SchedulerName, constants.SchedulerName)
 	assert.Equal(t, holder.pod.Name, "ph-name")
 	assert.Equal(t, holder.pod.Namespace, namespace)
@@ -132,7 +133,7 @@ func TestNewPlaceholder(t *testing.T) {
 		"labelKey1": "labelKeyValue1",
 	})
 	assert.Equal(t, len(holder.pod.Annotations), 7, "unexpected number of annotations")
-	assert.Equal(t, holder.pod.Annotations[constants.AnnotationTaskGroupName], app.taskGroups[0].Name)
+	assert.Equal(t, holder.pod.Annotations[constants.AnnotationTaskGroupName], tgs[0].Name)
 	assert.Equal(t, holder.pod.Annotations[constants.AnnotationPlaceholderFlag], constants.True)
 	assert.Equal(t, holder.pod.Annotations["annotationKey0"], "annotationValue0")
 	assert.Equal(t, holder.pod.Annotations["annotationKey1"], "annotationValue1")
@@ -163,7 +164,8 @@ func TestNewPlaceholderWithNodeSelectors(t *testing.T) {
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, len(holder.pod.Spec.NodeSelector), 2)
 	assert.Equal(t, holder.pod.Spec.NodeSelector["nodeType"], "test")
 	assert.Equal(t, holder.pod.Spec.NodeSelector["nodeState"], "healthy")
@@ -178,7 +180,8 @@ func TestNewPlaceholderWithTolerations(t *testing.T) {
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, len(holder.pod.Spec.Tolerations), 1)
 	tlr := holder.pod.Spec.Tolerations[0]
 	assert.Equal(t, tlr.Key, "key1")
@@ -196,7 +199,8 @@ func TestNewPlaceholderWithAffinity(t *testing.T) {
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t,
 		len(holder.pod.Spec.Affinity.PodAffinity.RequiredDuringSchedulingIgnoredDuringExecution), 1)
 	term := holder.pod.Spec.Affinity.PodAffinity.RequiredDuringSchedulingIgnoredDuringExecution
 	assert.Equal(t, term[0].TopologyKey, "topologyKey")
@@ -215,14 +219,16 @@ func TestNewPlaceholderTaskGroupsDefinition(t *testing.T) {
 	app := NewApplication(appID, queue,
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, "", holder.pod.Annotations[constants.AnnotationTaskGroups])
 
 	app = NewApplication(appID, queue,
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
 	app.setTaskGroupsDefinition("taskGroupsDef")
-	holder = newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs = app.getTaskGroups()
+	holder = newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, "taskGroupsDef", holder.pod.Annotations[constants.AnnotationTaskGroups])
 	var priority *int32
 	assert.Equal(t, priority, holder.pod.Spec.Priority)
@@ -234,7 +240,9 @@ func TestNewPlaceholderExtendedResources(t *testing.T) {
 	app := NewApplication(appID, queue,
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Requests), 5, "expected requests not found")
 	assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Limits), 5, "expected limits not found")
 	assert.Equal(t, holder.pod.Spec.Containers[0].Resources.Limits[gpu], holder.pod.Spec.Containers[0].Resources.Requests[gpu], "gpu: expected same value for request and limit")
@@ -271,7 +279,8 @@ func TestNewPlaceholderWithPriorityClassName(t *testing.T) {
 	app.taskMap[taskID1] = task1
 	app.setOriginatingTask(task1)
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Requests), 5, "expected requests not found")
 	assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Limits), 5, "expected limits not found")
 	assert.Equal(t, holder.pod.Spec.Containers[0].Resources.Limits[gpu], holder.pod.Spec.Containers[0].Resources.Requests[gpu], "gpu: expected same value for request and limit")
@@ -287,7 +296,8 @@ func TestNewPlaceholderWithTopologySpreadConstraints(t *testing.T) {
 		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
 	app.setTaskGroups(taskGroups)
 
-	holder := newPlaceholder("ph-name", app, app.taskGroups[0])
+	tgs := app.getTaskGroups()
+	holder := newPlaceholder("ph-name", app, tgs[0])
 	assert.Equal(t, len(holder.pod.Spec.TopologySpreadConstraints), 1)
 	assert.Equal(t, holder.pod.Spec.TopologySpreadConstraints[0].MaxSkew, int32(1))
 	assert.Equal(t, holder.pod.Spec.TopologySpreadConstraints[0].TopologyKey, v1.LabelTopologyZone)
@@ -297,3 +307,30 @@ func TestNewPlaceholderWithTopologySpreadConstraints(t *testing.T) {
 		"labelKey1": "labelKeyValue1",
 	})
 }
+
+func TestNewPlaceholderWithDuplicatedTaskGroup(t *testing.T) {
+	mockedSchedulerAPI := newMockSchedulerAPI()
+
+	// in this case, assume pod1 triggered the creation of the app
+	pod1 := &v1.Pod{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "Pod",
+			APIVersion: "v1",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name: "gang-scheduling-job-app01-0",
+			UID:  "UID-01",
+		},
+	}
+	app := NewApplication(appID, queue,
+		"bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI)
+
+	duplicatedTaskGroups := make([]TaskGroup, 0, 2)
+	duplicatedTaskGroups = append(duplicatedTaskGroups, taskGroups[0])
+	duplicatedTaskGroups = append(duplicatedTaskGroups, taskGroups[0])
+
+	app.checkTaskGroups(duplicatedTaskGroups, pod1)
+	app.setTaskGroups(duplicatedTaskGroups)
+	tgs := app.getTaskGroups()
+	assert.Equal(t, len(tgs), 1)
+}
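The new test pins down the group count; the companion effect, that the placeholder ask is computed once per unique group name rather than once per list entry, follows from the setTaskGroups loop above. A standalone sketch of that arithmetic, with the shim's common.Add and common.GetTGResource over si.Resource stood in for by plain integers:

package main

import "fmt"

type TaskGroup struct {
	Name       string
	MinMember  int64
	MinCPUMill int64 // stand-in for the cpu entry of MinResource, in millicores
}

func main() {
	groups := []TaskGroup{
		{Name: "tg-1", MinMember: 10, MinCPUMill: 500},
		{Name: "tg-1", MinMember: 10, MinCPUMill: 500}, // duplicate: contributes nothing
	}
	seen := make(map[string]bool)
	var askCPU, askPods int64
	for _, tg := range groups {
		if seen[tg.Name] {
			continue // mirrors setTaskGroups: duplicated names are skipped
		}
		seen[tg.Name] = true
		askCPU += tg.MinCPUMill * tg.MinMember
		askPods += tg.MinMember
	}
	fmt.Println(askCPU, askPods) // 5000 10, not 10000 20
}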