Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: shutdown gracefully on TERM or INT signals #273

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog/unreleased/pull-273
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Change: Server is now shutdown cleanly on TERM or INT signals

Server now listens for TERM and INT signals and cleanly closes down the http.Server and listener.

This is particularly useful when listening on a unix socket, as the server will remove the socket file from it shuts down.

https://github.com/restic/rest-server/pull/273
199 changes: 118 additions & 81 deletions cmd/rest-server/main.go
Original file line number Diff line number Diff line change
@@ -1,159 +1,196 @@
package main

import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
"sync"
"syscall"

restserver "github.com/restic/rest-server"
"github.com/spf13/cobra"
)

// cmdRoot is the base command when no other command has been specified.
var cmdRoot = &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
RunE: runRoot,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
type restServerApp struct {
CmdRoot *cobra.Command
Server restserver.Server
CpuProfile string

var server = restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
listenerAddressMu sync.Mutex
listenerAddress net.Addr // set after startup
}

var (
cpuProfile string
)

func init() {
flags := cmdRoot.Flags()
flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file")
flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages")
flags.StringVar(&server.Listen, "listen", server.Listen, "listen address")
flags.StringVar(&server.Log, "log", server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&server.Path, "path", server.Path, "data directory")
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&server.HtpasswdPath, "htpasswd-file", server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
// cmdRoot is the base command when no other command has been specified.
func newRestServerApp() *restServerApp {
rv := &restServerApp{
CmdRoot: &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
},
Server: restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
},
}
rv.CmdRoot.RunE = rv.runRoot
flags := rv.CmdRoot.Flags()

flags.StringVar(&rv.CpuProfile, "cpu-profile", rv.CpuProfile, "write CPU profile to file")
flags.BoolVar(&rv.Server.Debug, "debug", rv.Server.Debug, "output debug messages")
flags.StringVar(&rv.Server.Listen, "listen", rv.Server.Listen, "listen address")
flags.StringVar(&rv.Server.Log, "log", rv.Server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&rv.Server.MaxRepoSize, "max-size", rv.Server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&rv.Server.Path, "path", rv.Server.Path, "data directory")
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&server.PrometheusNoAuth, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode")
flags.BoolVar(&rv.Server.PrivateRepos, "private-repos", rv.Server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")

return rv
}

var version = "0.12.1-dev"

func tlsSettings() (bool, string, string, error) {
func (app *restServerApp) tlsSettings() (bool, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
if !app.Server.TLS && (app.Server.TLSKey != "" || app.Server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
} else if !server.TLS {
} else if !app.Server.TLS {
return false, "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
if app.Server.TLSKey != "" {
key = app.Server.TLSKey
} else {
key = filepath.Join(server.Path, "private_key")
key = filepath.Join(app.Server.Path, "private_key")
}
if server.TLSCert != "" {
cert = server.TLSCert
if app.Server.TLSCert != "" {
cert = app.Server.TLSCert
} else {
cert = filepath.Join(server.Path, "public_key")
cert = filepath.Join(app.Server.Path, "public_key")
}
return server.TLS, key, cert, nil
return app.Server.TLS, key, cert, nil
}

func runRoot(cmd *cobra.Command, args []string) error {
// returns the address that the app is listening on.
// returns nil if the application hasn't finished starting yet
func (app *restServerApp) ListenerAddress() net.Addr {
app.listenerAddressMu.Lock()
defer app.listenerAddressMu.Unlock()
return app.listenerAddress
}

func (app *restServerApp) runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0)

log.Printf("Data directory: %s", server.Path)
log.Printf("Data directory: %s", app.Server.Path)

if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if app.CpuProfile != "" {
f, err := os.Create(app.CpuProfile)
if err != nil {
return err
}
defer f.Close()

if err := pprof.StartCPUProfile(f); err != nil {
return err
}
log.Println("CPU profiling enabled")
defer pprof.StopCPUProfile()

// clean profiling shutdown on sigint
sigintCh := make(chan os.Signal, 1)
go func() {
for range sigintCh {
pprof.StopCPUProfile()
log.Println("Stopped CPU profiling")
err := f.Close()
if err != nil {
log.Printf("error closing CPU profile file: %v", err)
}
os.Exit(130)
}
}()
signal.Notify(sigintCh, syscall.SIGINT)
log.Println("CPU profiling enabled")
defer log.Println("Stopped CPU profiling")
}

if server.NoAuth {
if app.Server.NoAuth {
log.Println("Authentication disabled")
} else {
log.Println("Authentication enabled")
}

handler, err := restserver.NewHandler(&server)
handler, err := restserver.NewHandler(&app.Server)
if err != nil {
log.Fatalf("error: %v", err)
}

if server.PrivateRepos {
if app.Server.PrivateRepos {
log.Println("Private repositories enabled")
} else {
log.Println("Private repositories disabled")
}

enabledTLS, privateKey, publicKey, err := tlsSettings()
enabledTLS, privateKey, publicKey, err := app.tlsSettings()
if err != nil {
return err
}

listener, err := findListener(server.Listen)
listener, err := findListener(app.Server.Listen)
if err != nil {
return fmt.Errorf("unable to listen: %w", err)
}

if !enabledTLS {
err = http.Serve(listener, handler)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
// set listener address, this is useful for tests
app.listenerAddressMu.Lock()
app.listenerAddress = listener.Addr()
app.listenerAddressMu.Unlock()

srv := &http.Server{
Handler: handler,
}

// run server in background
go func() {
if !enabledTLS {
err = srv.Serve(listener)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = srv.ServeTLS(listener, publicKey, privateKey)
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen and serve returned err: %v", err)
}
}()

// wait until done
<-app.CmdRoot.Context().Done()

// gracefully shutdown server
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("server shutdown returned an err: %w", err)
}

return err
log.Println("shutdown cleanly")
return nil
}

func main() {
if err := cmdRoot.Execute(); err != nil {
// create context to be notified on interrupt or term signal so that we can shutdown cleanly
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

if err := newRestServerApp().CmdRoot.ExecuteContext(ctx); err != nil {
log.Fatalf("error: %v", err)
}
}
Loading
Loading