diff --git a/api/server.go b/api/server.go index ec605d49df1..8307d73dcf8 100644 --- a/api/server.go +++ b/api/server.go @@ -1,9 +1,12 @@ +// Package api contains the REST API implementation for k6. +// It also registers the services endpoints like pprof package api import ( "context" "fmt" "net/http" + _ "net/http/pprof" //nolint:gosec // Register pprof handlers "time" "github.com/sirupsen/logrus" @@ -15,18 +18,41 @@ import ( "go.k6.io/k6/metrics/engine" ) -func newHandler(cs *v1.ControlSurface) http.Handler { +func newHandler(cs *v1.ControlSurface, profilingEnabled bool) http.Handler { mux := http.NewServeMux() mux.Handle("/v1/", v1.NewHandler(cs)) mux.Handle("/ping", handlePing(cs.RunState.Logger)) mux.Handle("/", handlePing(cs.RunState.Logger)) + + injectProfilerHandler(mux, profilingEnabled) + return mux } +func injectProfilerHandler(mux *http.ServeMux, profilingEnabled bool) { + var handler http.Handler + + handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Add("Content-Type", "text/plain; charset=utf-8") + _, _ = rw.Write([]byte("To enable profiling, please run k6 with the --profiling-enabled flag")) + }) + + if profilingEnabled { + handler = http.DefaultServeMux + } + + mux.Handle("/debug/pprof/", handler) +} + // GetServer returns a http.Server instance that can serve k6's REST API. func GetServer( - runCtx context.Context, addr string, runState *lib.TestRunState, - samples chan metrics.SampleContainer, me *engine.MetricsEngine, es *execution.Scheduler, + runCtx context.Context, + addr string, + profilingEnabled bool, + runState *lib.TestRunState, + samples chan metrics.SampleContainer, + me *engine.MetricsEngine, + es *execution.Scheduler, ) *http.Server { // TODO: reduce the control surface as much as possible? For example, if // we refactor the Runner API, we won't need to send the Samples channel. @@ -38,7 +64,7 @@ func GetServer( RunState: runState, } - mux := withLoggingHandler(runState.Logger, newHandler(cs)) + mux := withLoggingHandler(runState.Logger, newHandler(cs, profilingEnabled)) return &http.Server{Addr: addr, Handler: mux, ReadHeaderTimeout: 10 * time.Second} } diff --git a/cmd/root.go b/cmd/root.go index 6bf12711d89..0b5e6e6d602 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -185,6 +185,12 @@ func rootCmdPersistentFlagSet(gs *state.GlobalState) *pflag.FlagSet { flags.BoolVarP(&gs.Flags.Verbose, "verbose", "v", gs.DefaultFlags.Verbose, "enable verbose logging") flags.BoolVarP(&gs.Flags.Quiet, "quiet", "q", gs.DefaultFlags.Quiet, "disable progress updates") flags.StringVarP(&gs.Flags.Address, "address", "a", gs.DefaultFlags.Address, "address for the REST API server") + flags.BoolVar( + &gs.Flags.ProfilingEnabled, + "profiling-enabled", + gs.DefaultFlags.ProfilingEnabled, + "enable profiling (pprof) endpoints, k6's REST API should be enabled as well", + ) return flags } diff --git a/cmd/run.go b/cmd/run.go index 2c640513c17..be6452149de 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -270,10 +270,20 @@ func (c *cmdRun) run(cmd *cobra.Command, args []string) (err error) { srvCtx, srvCancel := context.WithCancel(globalCtx) defer srvCancel() - srv := api.GetServer(runCtx, c.gs.Flags.Address, testRunState, samples, metricsEngine, execScheduler) + srv := api.GetServer( + runCtx, + c.gs.Flags.Address, c.gs.Flags.ProfilingEnabled, + testRunState, + samples, + metricsEngine, + execScheduler, + ) go func() { defer apiWG.Done() logger.Debugf("Starting the REST API server on %s", c.gs.Flags.Address) + if c.gs.Flags.ProfilingEnabled { + logger.Debugf("Profiling exposed on http://%s/debug/pprof/", c.gs.Flags.Address) + } if aerr := srv.ListenAndServe(); aerr != nil && !errors.Is(aerr, http.ErrServerClosed) { // Only exit k6 if the user has explicitly set the REST API address if cmd.Flags().Lookup("address").Changed { diff --git a/cmd/state/state.go b/cmd/state/state.go index 3facd99d019..35b2da2f0ab 100644 --- a/cmd/state/state.go +++ b/cmd/state/state.go @@ -134,21 +134,23 @@ func NewGlobalState(ctx context.Context) *GlobalState { // GlobalFlags contains global config values that apply for all k6 sub-commands. type GlobalFlags struct { - ConfigFilePath string - Quiet bool - NoColor bool - Address string - LogOutput string - LogFormat string - Verbose bool + ConfigFilePath string + Quiet bool + NoColor bool + Address string + ProfilingEnabled bool + LogOutput string + LogFormat string + Verbose bool } // GetDefaultFlags returns the default global flags. func GetDefaultFlags(homeDir string) GlobalFlags { return GlobalFlags{ - Address: "localhost:6565", - ConfigFilePath: filepath.Join(homeDir, "loadimpact", "k6", defaultConfigFileName), - LogOutput: "stderr", + Address: "localhost:6565", + ProfilingEnabled: false, + ConfigFilePath: filepath.Join(homeDir, "loadimpact", "k6", defaultConfigFileName), + LogOutput: "stderr", } } diff --git a/cmd/ui.go b/cmd/ui.go index a2aaab725c1..3367223f1af 100644 --- a/cmd/ui.go +++ b/cmd/ui.go @@ -124,6 +124,10 @@ func printExecutionDescription( } fmt.Fprintf(buf, " output: %s\n", valueColor.Sprint(strings.Join(outputDescriptions, ", "))) + if gs.Flags.ProfilingEnabled && gs.Flags.Address != "" { + fmt.Fprintf(buf, " profiling: %s\n", valueColor.Sprintf("http://%s/debug/pprof/", gs.Flags.Address)) + } + fmt.Fprintf(buf, "\n") maxDuration, _ := lib.GetEndOffset(execPlan)