diff --git a/util/stopwaiter/stopwaiter.go b/util/stopwaiter/stopwaiter.go index 974178e1cf..c242ac26ab 100644 --- a/util/stopwaiter/stopwaiter.go +++ b/util/stopwaiter/stopwaiter.go @@ -119,8 +119,14 @@ func getAllStackTraces() string { func (s *StopWaiterSafe) stopAndWaitImpl(warningTimeout time.Duration) error { s.StopOnly() + if !s.Started() { + // No need to wait, because nothing can be started if it's already stopped. + return nil + } // Even if StopOnly has been previously called, make sure we wait for everything to shut down. // Otherwise, a StopOnly call followed by StopAndWait might return early without waiting. + // At this point started must be true (because it was true above and cannot go back to false), + // so GetWaitChannel won't return an error. waitChan, err := s.GetWaitChannel() if err != nil { return err diff --git a/util/stopwaiter/stopwaiter_test.go b/util/stopwaiter/stopwaiter_test.go index c561e1f43b..68e49ac2be 100644 --- a/util/stopwaiter/stopwaiter_test.go +++ b/util/stopwaiter/stopwaiter_test.go @@ -5,6 +5,7 @@ package stopwaiter import ( "context" + "sync/atomic" "testing" "time" @@ -73,3 +74,19 @@ func TestStopWaiterStopAndWaitMultipleTimes(t *testing.T) { sw.StopAndWait() sw.StopAndWait() } + +func TestStopWaiterStopOnlyThenStopAndWait(t *testing.T) { + t.Parallel() + sw := StopWaiter{} + sw.Start(context.Background(), &TestStruct{}) + var threadStopping atomic.Bool + sw.LaunchThread(func(context.Context) { + time.Sleep(time.Second) + threadStopping.Store(true) + }) + sw.StopOnly() + sw.StopAndWait() + if !threadStopping.Load() { + t.Error("StopAndWait returned before background thread stopped") + } +}