diff --git a/src/go/api/experiment/experiment.go b/src/go/api/experiment/experiment.go index 1f1026f4..120e80cb 100644 --- a/src/go/api/experiment/experiment.go +++ b/src/go/api/experiment/experiment.go @@ -92,6 +92,18 @@ func init() { }) } +// Hook is a function to be called during the different lifecycle stages of an +// experiment. The first argument is the experiment stage (create, start, stop, +// delete), and the second argument is the experiment name. +type Hook func(string, string) + +var hooks = make(map[string][]Hook) + +// RegisterHook registers a Hook for the given experiment stage. +func RegisterHook(stage string, hook Hook) { + hooks[stage] = append(hooks[stage], hook) +} + // List collects experiments, each in a struct that references the latest // versioned experiment spec and status. It returns a slice of experiments and // any errors encountered while gathering and decoding them. @@ -252,6 +264,10 @@ func Create(ctx context.Context, opts ...CreateOption) error { return fmt.Errorf("creating experiment config: %w", err) } + for _, hook := range hooks["create"] { + hook("create", o.name) + } + return nil } @@ -543,6 +559,10 @@ func Start(ctx context.Context, opts ...StartOption) error { return fmt.Errorf("updating experiment config: %w", err) } + for _, hook := range hooks["start"] { + hook("start", o.name) + } + return nil } @@ -587,6 +607,10 @@ func Stop(name string) error { errors = multierror.Append(errors, fmt.Errorf("updating experiment config: %w", err)) } + for _, hook := range hooks["stop"] { + hook("stop", name) + } + return errors } @@ -746,6 +770,10 @@ func Delete(name string) error { errors = multierror.Append(errors, fmt.Errorf("deleting experiment base directory: %w", err)) } + for _, hook := range hooks["delete"] { + hook("delete", name) + } + return errors } diff --git a/src/go/api/scorch/break.go b/src/go/api/scorch/break.go index ec0e824d..8b62e310 100644 --- a/src/go/api/scorch/break.go +++ b/src/go/api/scorch/break.go @@ -135,8 
+135,7 @@ func (this Break) breakPoint(ctx context.Context, stage Action) error { } select { - case <-ctx.Done(): - // don't return ctx error here so we can clean up tap and internet access below + case <-ctx.Done(): // this blocks until the context is canceled case <-done: // this blocks until web terminal is exited } } diff --git a/src/go/api/scorch/scorchexe/cancelers.go b/src/go/api/scorch/scorchexe/cancelers.go new file mode 100644 index 00000000..2011c4bc --- /dev/null +++ b/src/go/api/scorch/scorchexe/cancelers.go @@ -0,0 +1,65 @@ +package scorchexe + +import ( + "context" + "fmt" + "strings" + "sync" +) + +var ( + cancelers = make(map[string]context.CancelFunc) + cancelersMu sync.Mutex +) + +func AddCanceler(ctx context.Context, exp string, run int) context.Context { + key := fmt.Sprintf("%s/%d", exp, run) + + cancelersMu.Lock() + defer cancelersMu.Unlock() + + ctx, cancel := context.WithCancel(ctx) + cancelers[key] = cancel + + return ctx +} + +func HasCanceler(exp string, run int) bool { + key := fmt.Sprintf("%s/%d", exp, run) + + cancelersMu.Lock() + defer cancelersMu.Unlock() + + _, ok := cancelers[key] + + return ok +} + +func GetExperimentCancelers(exp string) []context.CancelFunc { + var expCancelers []context.CancelFunc + + cancelersMu.Lock() + defer cancelersMu.Unlock() + + for run := range cancelers { + // run keys are prefixed with the name of the experiment + if strings.HasPrefix(run, exp+"/") { + expCancelers = append(expCancelers, cancelers[run]) + delete(cancelers, run) + } + } + + return expCancelers +} + +func GetCanceler(exp string, run int) context.CancelFunc { + key := fmt.Sprintf("%s/%d", exp, run) + + cancelersMu.Lock() + defer cancelersMu.Unlock() + + cancel := cancelers[key] + delete(cancelers, key) + + return cancel +} diff --git a/src/go/types/version/v1/experiment.go b/src/go/types/version/v1/experiment.go index ac5ae86f..07b9f3c0 100644 --- a/src/go/types/version/v1/experiment.go +++ b/src/go/types/version/v1/experiment.go @@ 
-302,22 +302,42 @@ func (this ExperimentStatus) StartTime() string { } func (this ExperimentStatus) AppStatus() map[string]any { + if this.AppsF == nil { + return make(map[string]any) + } + return this.AppsF } func (this ExperimentStatus) AppFrequency() map[string]string { + if this.FrequencyF == nil { + return make(map[string]string) + } + return this.FrequencyF } func (this ExperimentStatus) AppRunning() map[string]bool { + if this.RunningF == nil { + return make(map[string]bool) + } + return this.RunningF } func (this ExperimentStatus) VLANs() map[string]int { + if this.VLANsF == nil { + return make(map[string]int) + } + return this.VLANsF } func (this ExperimentStatus) Schedules() map[string]string { + if this.SchedulesF == nil { + return make(map[string]string) + } + return this.SchedulesF } @@ -365,10 +385,18 @@ func (this *ExperimentStatus) SetAppRunning(a string, r bool) { } func (this *ExperimentStatus) SetVLANs(v map[string]int) { + if this.VLANsF == nil { + this.VLANsF = make(map[string]int) + } + this.VLANsF = v } func (this *ExperimentStatus) SetSchedule(s map[string]string) { + if this.SchedulesF == nil { + this.SchedulesF = make(map[string]string) + } + this.SchedulesF = s } diff --git a/src/go/web/scorch/handlers.go b/src/go/web/scorch/handlers.go index ea6be683..152894f0 100644 --- a/src/go/web/scorch/handlers.go +++ b/src/go/web/scorch/handlers.go @@ -27,6 +27,14 @@ import ( "golang.org/x/net/websocket" ) +func init() { + experiment.RegisterHook("stop", func(stage, name string) { + for _, cancel := range scorchexe.GetExperimentCancelers(name) { + cancel() + } + }) +} + type termClient struct { id string ws *websocket.Conn @@ -48,8 +56,6 @@ var ( termClientIDs = make(map[string]chan struct{}) - cancelers = make(map[string]context.CancelFunc) - mu sync.Mutex ) @@ -588,33 +594,19 @@ func StartPipeline(w http.ResponseWriter, r *http.Request) error { return weberror.NewWebError(err, "unable to get experiment %s from store", name) } - key := 
fmt.Sprintf("%s/%d", name, run) - - // protect `cancelers` map - mu.Lock() - defer mu.Unlock() - - // TODO (btr): we some how got stuck here at least once where a scorch run was - // started, then the experiment was killed, but the scorch run key stayed in - // the cancelers map. I'm still not entirely sure how this could happen, but - // if the mutex lock isn't blocked then we could do something like trigger - // reaping of scorch runs for experiments that have been stopped. We could - // also base the cancel context for a scorch run off the cancel context for - // the experiment, but in order to do this we'll need to refactor code to - // avoid an import loop. - - if _, ok := cancelers[key]; ok { + if scorchexe.HasCanceler(name, run) { return weberror.NewWebError(nil, "Scorch run already executing for experiment %s", name) } // We don't want to use the HTTP request's context here. - ctx, cancel := context.WithCancel(context.Background()) + ctx = scorchexe.AddCanceler(context.Background(), name, run) ctx = app.SetContextTriggerUI(ctx) - cancelers[key] = cancel go func() { log.Debug("executing Scorch run %d for experiment %s", run, name) + key := fmt.Sprintf("%s/%d", name, run) + broker.Broadcast( broker.NewRequestPolicy("experiments/trigger", "create", name), broker.NewResource("apps/scorch", key, "start"), @@ -641,15 +633,12 @@ func StartPipeline(w http.ResponseWriter, r *http.Request) error { ) } - // protect `cancelers` map - mu.Lock() - defer mu.Unlock() - // Ensure context is canceled to avoid leakage. It's okay to call the // `cancel` function multiple times. It's a no-op after the first time it's // called. 
- cancel() - delete(cancelers, key) + if cancel := scorchexe.GetCanceler(name, run); cancel != nil { + cancel() + } }() w.WriteHeader(http.StatusNoContent) @@ -679,17 +668,12 @@ func CancelPipeline(w http.ResponseWriter, r *http.Request) error { return err.SetStatus(http.StatusForbidden) } - key := fmt.Sprintf("%s/%d", name, run) - - // protect `cancelers` map - mu.Lock() - defer mu.Unlock() - - if cancel, ok := cancelers[key]; ok { + if cancel := scorchexe.GetCanceler(name, run); cancel != nil { log.Debug("canceling Scorch run %d for experiment %s", run, name) cancel() - delete(cancelers, key) + + key := fmt.Sprintf("%s/%d", name, run) broker.Broadcast( broker.NewRequestPolicy("experiments/trigger", "delete", name),