diff --git a/pkg/common/test/schedulerapi_mock.go b/pkg/common/test/schedulerapi_mock.go index c0b781cbf..3d7f11471 100644 --- a/pkg/common/test/schedulerapi_mock.go +++ b/pkg/common/test/schedulerapi_mock.go @@ -27,10 +27,10 @@ import ( ) type SchedulerAPIMock struct { - registerCount int32 - UpdateAllocationCount int32 - UpdateApplicationCount int32 - UpdateNodeCount int32 + registerCount atomic.Int32 + UpdateAllocationCount atomic.Int32 + UpdateApplicationCount atomic.Int32 + UpdateNodeCount atomic.Int32 registerFn func(request *si.RegisterResourceManagerRequest, callback api.ResourceManagerCallback) (*si.RegisterResourceManagerResponse, error) UpdateAllocationFn func(request *si.AllocationRequest) error @@ -41,10 +41,6 @@ type SchedulerAPIMock struct { func NewSchedulerAPIMock() *SchedulerAPIMock { return &SchedulerAPIMock{ - registerCount: int32(0), - UpdateAllocationCount: int32(0), - UpdateApplicationCount: int32(0), - UpdateNodeCount: int32(0), registerFn: func(request *si.RegisterResourceManagerRequest, callback api.ResourceManagerCallback) (response *si.RegisterResourceManagerResponse, e error) { return nil, nil @@ -93,28 +89,28 @@ func (api *SchedulerAPIMock) RegisterResourceManager(request *si.RegisterResourc callback api.ResourceManagerCallback) (*si.RegisterResourceManagerResponse, error) { api.lock.Lock() defer api.lock.Unlock() - atomic.AddInt32(&api.registerCount, 1) + api.registerCount.Add(1) return api.registerFn(request, callback) } func (api *SchedulerAPIMock) UpdateAllocation(request *si.AllocationRequest) error { api.lock.Lock() defer api.lock.Unlock() - atomic.AddInt32(&api.UpdateAllocationCount, 1) + api.UpdateAllocationCount.Add(1) return api.UpdateAllocationFn(request) } func (api *SchedulerAPIMock) UpdateApplication(request *si.ApplicationRequest) error { api.lock.Lock() defer api.lock.Unlock() - atomic.AddInt32(&api.UpdateApplicationCount, 1) + api.UpdateApplicationCount.Add(1) return api.UpdateApplicationFn(request) } func (api *SchedulerAPIMock) UpdateNode(request *si.NodeRequest) error { api.lock.Lock() defer api.lock.Unlock() - atomic.AddInt32(&api.UpdateNodeCount, 1) + api.UpdateNodeCount.Add(1) return api.UpdateNodeFn(request) } @@ -125,26 +121,26 @@ func (api *SchedulerAPIMock) UpdateConfiguration(request *si.UpdateConfiguration } func (api *SchedulerAPIMock) GetRegisterCount() int32 { - return atomic.LoadInt32(&api.registerCount) + return api.registerCount.Load() } func (api *SchedulerAPIMock) GetUpdateAllocationCount() int32 { - return atomic.LoadInt32(&api.UpdateAllocationCount) + return api.UpdateAllocationCount.Load() } func (api *SchedulerAPIMock) GetUpdateApplicationCount() int32 { - return atomic.LoadInt32(&api.UpdateApplicationCount) + return api.UpdateApplicationCount.Load() } func (api *SchedulerAPIMock) GetUpdateNodeCount() int32 { - return atomic.LoadInt32(&api.UpdateNodeCount) + return api.UpdateNodeCount.Load() } func (api *SchedulerAPIMock) ResetAllCounters() { - atomic.StoreInt32(&api.registerCount, 0) - atomic.StoreInt32(&api.UpdateAllocationCount, 0) - atomic.StoreInt32(&api.UpdateApplicationCount, 0) - atomic.StoreInt32(&api.UpdateNodeCount, 0) + api.registerCount.Store(0) + api.UpdateAllocationCount.Store(0) + api.UpdateApplicationCount.Store(0) + api.UpdateNodeCount.Store(0) } func (api *SchedulerAPIMock) Stop() { diff --git a/pkg/dispatcher/dispatch_test.go b/pkg/dispatcher/dispatch_test.go index cdf9a94be..8a2cc3819 100644 --- a/pkg/dispatcher/dispatch_test.go +++ b/pkg/dispatcher/dispatch_test.go @@ -22,7 +22,6 @@ import ( "fmt" "runtime" "strings" - "sync/atomic" "testing" "time" @@ -188,7 +187,7 @@ func TestEventWillNotBeLostWhenEventChannelIsFull(t *testing.T) { } // check event channel is full and some events are dispatched asynchronously - assert.Assert(t, atomic.LoadInt32(&asyncDispatchCount) > 0) + assert.Assert(t, asyncDispatchCount.Load() > 0) // wait until all events are handled dispatcher.drain() @@ -198,7 +197,7 @@ func TestEventWillNotBeLostWhenEventChannelIsFull(t *testing.T) { // assert all event are handled assert.Equal(t, recorder.size(), numEvents) - assert.Assert(t, atomic.LoadInt32(&asyncDispatchCount) == 0) + assert.Assert(t, asyncDispatchCount.Load() == 0) // ensure state is stopped assert.Equal(t, dispatcher.isRunning(), false) @@ -241,7 +240,7 @@ func TestDispatchTimeout(t *testing.T) { // 2nd one should be added to the channel // 3rd one should be posted as an async request time.Sleep(100 * time.Millisecond) - assert.Equal(t, atomic.LoadInt32(&asyncDispatchCount), int32(1)) + assert.Equal(t, asyncDispatchCount.Load(), int32(1)) // verify Dispatcher#asyncDispatch is called buf := make([]byte, 1<<16) @@ -250,7 +249,7 @@ func TestDispatchTimeout(t *testing.T) { // wait until async dispatch routine times out err := utils.WaitForCondition(func() bool { - return atomic.LoadInt32(&asyncDispatchCount) == int32(0) + return asyncDispatchCount.Load() == int32(0) }, 100*time.Millisecond, DispatchTimeout+AsyncDispatchCheckInterval) assert.NilError(t, err) diff --git a/pkg/dispatcher/dispatcher.go b/pkg/dispatcher/dispatcher.go index 980bba3ee..b0960f192 100644 --- a/pkg/dispatcher/dispatcher.go +++ b/pkg/dispatcher/dispatcher.go @@ -47,7 +47,7 @@ var ( AsyncDispatchLimit int32 AsyncDispatchCheckInterval = 3 * time.Second DispatchTimeout time.Duration - asyncDispatchCount int32 = 0 + asyncDispatchCount atomic.Int32 = atomic.Int32{} ) // central dispatcher that dispatches scheduling events. @@ -169,14 +169,14 @@ func (p *Dispatcher) dispatch(event events.SchedulingEvent) error { // async-dispatch try to enqueue the event in every 3 seconds util timeout, // it's only called when event channel is full. func (p *Dispatcher) asyncDispatch(event events.SchedulingEvent) { - count := atomic.AddInt32(&asyncDispatchCount, 1) + count := asyncDispatchCount.Add(1) log.Log(log.ShimDispatcher).Warn("event channel is full, transition to async-dispatch mode", zap.Int32("asyncDispatchCount", count)) if count > AsyncDispatchLimit { panic(fmt.Errorf("dispatcher exceeds async-dispatch limit")) } go func(beginTime time.Time, stop chan struct{}) { - defer atomic.AddInt32(&asyncDispatchCount, -1) + defer asyncDispatchCount.Add(-1) for p.isRunning() { select { case <-stop: