diff --git a/changelog/unreleased/pull-273 b/changelog/unreleased/pull-273 new file mode 100644 index 00000000..c8cb44fa --- /dev/null +++ b/changelog/unreleased/pull-273 @@ -0,0 +1,7 @@ +Change: Server is now shutdown cleanly on TERM or INT signals + +Server now listens for TERM and INT signals and cleanly closes down the http.Server and listener. + +This is particularly useful when listening on a unix socket, as the server will remove the socket file from it shuts down. + +https://github.com/restic/rest-server/pull/273 diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index eef0cd41..f0d5d248 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -1,159 +1,196 @@ package main import ( + "context" "errors" "fmt" "log" + "net" "net/http" "os" "os/signal" "path/filepath" "runtime" "runtime/pprof" + "sync" "syscall" restserver "github.com/restic/rest-server" "github.com/spf13/cobra" ) -// cmdRoot is the base command when no other command has been specified. -var cmdRoot = &cobra.Command{ - Use: "rest-server", - Short: "Run a REST server for use with restic", - SilenceErrors: true, - SilenceUsage: true, - RunE: runRoot, - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 0 { - return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0]) - } - return nil - }, - Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH), -} +type restServerApp struct { + CmdRoot *cobra.Command + Server restserver.Server + CpuProfile string -var server = restserver.Server{ - Path: filepath.Join(os.TempDir(), "restic"), - Listen: ":8000", + listenerAddressMu sync.Mutex + listenerAddress net.Addr // set after startup } -var ( - cpuProfile string -) - -func init() { - flags := cmdRoot.Flags() - flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file") - flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages") - flags.StringVar(&server.Listen, "listen", server.Listen, "listen address") - flags.StringVar(&server.Log, "log", server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)") - flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes") - flags.StringVar(&server.Path, "path", server.Path, "data directory") - flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support") - flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path") - flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path") - flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication") - flags.StringVar(&server.HtpasswdPath, "htpasswd-file", server.HtpasswdPath, "location of .htpasswd file (default: \"/.htpasswd)\"") - flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload, +// cmdRoot is the base command when no other command has been specified. +func newRestServerApp() *restServerApp { + rv := &restServerApp{ + CmdRoot: &cobra.Command{ + Use: "rest-server", + Short: "Run a REST server for use with restic", + SilenceErrors: true, + SilenceUsage: true, + Args: func(cmd *cobra.Command, args []string) error { + if len(args) != 0 { + return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0]) + } + return nil + }, + Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH), + }, + Server: restserver.Server{ + Path: filepath.Join(os.TempDir(), "restic"), + Listen: ":8000", + }, + } + rv.CmdRoot.RunE = rv.runRoot + flags := rv.CmdRoot.Flags() + + flags.StringVar(&rv.CpuProfile, "cpu-profile", rv.CpuProfile, "write CPU profile to file") + flags.BoolVar(&rv.Server.Debug, "debug", rv.Server.Debug, "output debug messages") + flags.StringVar(&rv.Server.Listen, "listen", rv.Server.Listen, "listen address") + flags.StringVar(&rv.Server.Log, "log", rv.Server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)") + flags.Int64Var(&rv.Server.MaxRepoSize, "max-size", rv.Server.MaxRepoSize, "the maximum size of the repository in bytes") + flags.StringVar(&rv.Server.Path, "path", rv.Server.Path, "data directory") + flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support") + flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path") + flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path") + flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable .htpasswd authentication") + flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"/.htpasswd)\"") + flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload, "do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device") - flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode") - flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo") - flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics") - flags.BoolVar(&server.PrometheusNoAuth, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint") + flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode") + flags.BoolVar(&rv.Server.PrivateRepos, "private-repos", rv.Server.PrivateRepos, "users can only access their private repo") + flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics") + flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint") + + return rv } var version = "0.12.1-dev" -func tlsSettings() (bool, string, string, error) { +func (app *restServerApp) tlsSettings() (bool, string, string, error) { var key, cert string - if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") { + if !app.Server.TLS && (app.Server.TLSKey != "" || app.Server.TLSCert != "") { return false, "", "", errors.New("requires enabled TLS") - } else if !server.TLS { + } else if !app.Server.TLS { return false, "", "", nil } - if server.TLSKey != "" { - key = server.TLSKey + if app.Server.TLSKey != "" { + key = app.Server.TLSKey } else { - key = filepath.Join(server.Path, "private_key") + key = filepath.Join(app.Server.Path, "private_key") } - if server.TLSCert != "" { - cert = server.TLSCert + if app.Server.TLSCert != "" { + cert = app.Server.TLSCert } else { - cert = filepath.Join(server.Path, "public_key") + cert = filepath.Join(app.Server.Path, "public_key") } - return server.TLS, key, cert, nil + return app.Server.TLS, key, cert, nil } -func runRoot(cmd *cobra.Command, args []string) error { +// returns the address that the app is listening on. +// returns nil if the application hasn't finished starting yet +func (app *restServerApp) ListenerAddress() net.Addr { + app.listenerAddressMu.Lock() + defer app.listenerAddressMu.Unlock() + return app.listenerAddress +} + +func (app *restServerApp) runRoot(cmd *cobra.Command, args []string) error { log.SetFlags(0) - log.Printf("Data directory: %s", server.Path) + log.Printf("Data directory: %s", app.Server.Path) - if cpuProfile != "" { - f, err := os.Create(cpuProfile) + if app.CpuProfile != "" { + f, err := os.Create(app.CpuProfile) if err != nil { return err } + defer f.Close() + if err := pprof.StartCPUProfile(f); err != nil { return err } - log.Println("CPU profiling enabled") + defer pprof.StopCPUProfile() - // clean profiling shutdown on sigint - sigintCh := make(chan os.Signal, 1) - go func() { - for range sigintCh { - pprof.StopCPUProfile() - log.Println("Stopped CPU profiling") - err := f.Close() - if err != nil { - log.Printf("error closing CPU profile file: %v", err) - } - os.Exit(130) - } - }() - signal.Notify(sigintCh, syscall.SIGINT) + log.Println("CPU profiling enabled") + defer log.Println("Stopped CPU profiling") } - if server.NoAuth { + if app.Server.NoAuth { log.Println("Authentication disabled") } else { log.Println("Authentication enabled") } - handler, err := restserver.NewHandler(&server) + handler, err := restserver.NewHandler(&app.Server) if err != nil { log.Fatalf("error: %v", err) } - if server.PrivateRepos { + if app.Server.PrivateRepos { log.Println("Private repositories enabled") } else { log.Println("Private repositories disabled") } - enabledTLS, privateKey, publicKey, err := tlsSettings() + enabledTLS, privateKey, publicKey, err := app.tlsSettings() if err != nil { return err } - listener, err := findListener(server.Listen) + listener, err := findListener(app.Server.Listen) if err != nil { return fmt.Errorf("unable to listen: %w", err) } - if !enabledTLS { - err = http.Serve(listener, handler) - } else { - log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) - err = http.ServeTLS(listener, handler, publicKey, privateKey) + // set listener address, this is useful for tests + app.listenerAddressMu.Lock() + app.listenerAddress = listener.Addr() + app.listenerAddressMu.Unlock() + + srv := &http.Server{ + Handler: handler, + } + + // run server in background + go func() { + if !enabledTLS { + err = srv.Serve(listener) + } else { + log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) + err = srv.ServeTLS(listener, publicKey, privateKey) + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("listen and serve returned err: %v", err) + } + }() + + // wait until done + <-app.CmdRoot.Context().Done() + + // gracefully shutdown server + if err := srv.Shutdown(context.Background()); err != nil { + return fmt.Errorf("server shutdown returned an err: %w", err) } - return err + log.Println("shutdown cleanly") + return nil } func main() { - if err := cmdRoot.Execute(); err != nil { + // create context to be notified on interrupt or term signal so that we can shutdown cleanly + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if err := newRestServerApp().CmdRoot.ExecuteContext(ctx); err != nil { log.Fatalf("error: %v", err) } } diff --git a/cmd/rest-server/main_test.go b/cmd/rest-server/main_test.go index 48bfa147..73669814 100644 --- a/cmd/rest-server/main_test.go +++ b/cmd/rest-server/main_test.go @@ -1,10 +1,18 @@ package main import ( + "context" + "errors" + "fmt" "io/ioutil" + "net/http" + "net/url" "os" "path/filepath" + "strings" + "sync" "testing" + "time" restserver "github.com/restic/rest-server" ) @@ -47,17 +55,17 @@ func TestTLSSettings(t *testing.T) { } for _, test := range tests { - + app := newRestServerApp() t.Run("", func(t *testing.T) { // defer func() { restserver.Server = defaultConfig }() if test.passed.Path != "" { - server.Path = test.passed.Path + app.Server.Path = test.passed.Path } - server.TLS = test.passed.TLS - server.TLSKey = test.passed.TLSKey - server.TLSCert = test.passed.TLSCert + app.Server.TLS = test.passed.TLS + app.Server.TLSKey = test.passed.TLSKey + app.Server.TLSCert = test.passed.TLSCert - gotTLS, gotKey, gotCert, err := tlsSettings() + gotTLS, gotKey, gotCert, err := app.tlsSettings() if err != nil && !test.expected.Error { t.Fatalf("tls_settings returned err (%v)", err) } @@ -146,3 +154,123 @@ func TestGetHandler(t *testing.T) { t.Errorf("NoAuth=false with .htpasswd: expected no error, got %v", err) } } + +// helper method to test the app. Starts app with passed arguments, +// then will call the callback function which can make requests against +// the application. If the callback function fails due to errors returned +// by http.Do() (i.e. *url.Error), then it will be retried until successful, +// or the passed timeout passes. +func testServerWithArgs(args []string, timeout time.Duration, cb func(context.Context, *restServerApp) error) error { + // create the app with passed args + app := newRestServerApp() + app.CmdRoot.SetArgs(args) + + // create context that will timeout + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // wait group for our client and server tasks + jobs := &sync.WaitGroup{} + jobs.Add(2) + + // run the server, saving the error + var serverErr error + go func() { + defer jobs.Done() + defer cancel() // if the server is stopped, no point keep the client alive + serverErr = app.CmdRoot.ExecuteContext(ctx) + }() + + // run the client, saving the error + var clientErr error + go func() { + defer jobs.Done() + defer cancel() // once the client is done, stop the server + + var urlError *url.Error + + // execute in loop, as we will retry for network errors + // (such as the server hasn't started yet) + for { + clientErr = cb(ctx, app) + switch { + case clientErr == nil: + return // success, we're done + case errors.As(clientErr, &urlError): + // if a network error (url.Error), then wait and retry + // as server may not be ready yet + select { + case <-time.After(time.Millisecond * 100): + continue + case <-ctx.Done(): // unless we run out of time first + clientErr = context.Canceled + return + } + default: + return // other error type, we're done + } + } + }() + + // wait for both to complete + jobs.Wait() + + // report back if either failed + if clientErr != nil || serverErr != nil { + return fmt.Errorf("client or server error, client: %v, server: %v", clientErr, serverErr) + } + + return nil +} + +func TestHttpListen(t *testing.T) { + td := t.TempDir() + + // create some content and parent dirs + if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil { + t.Fatal(err) + } + + for _, args := range [][]string{ + {"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:0"}, // test emphemeral port + {"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:9000"}, // test "normal" port + {"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:9000"}, // test that server was shutdown cleanly and that we can re-use that port + } { + err := testServerWithArgs(args, time.Second*10, func(ctx context.Context, app *restServerApp) error { + for _, test := range []struct { + Path string + StatusCode int + }{ + {"/repo1/", http.StatusMethodNotAllowed}, + {"/repo1/config", http.StatusOK}, + {"/repo2/config", http.StatusNotFound}, + } { + listenAddr := app.ListenerAddress() + if listenAddr == nil { + return &url.Error{} // return this type of err, as we know this will retry + } + port := strings.Split(listenAddr.String(), ":")[1] + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:%s%s", port, test.Path), nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + resp.Body.Close() + if resp.StatusCode != test.StatusCode { + return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path) + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + } +}