diff --git a/changelog/unreleased/pull-272 b/changelog/unreleased/pull-272 new file mode 100644 index 0000000..0d8cccc --- /dev/null +++ b/changelog/unreleased/pull-272 @@ -0,0 +1,12 @@ +Enhancement: Support listening on a unix socket + +To let rest-server listen on a unix socket, prefix the socket filename with `unix:` and pass it to the `--listen` option, for example `--listen unix:/tmp/foo`. + +This is useful in combination with remote port forwarding to enable remote server to backup locally, e.g. + +``` +rest-server --listen unix:/tmp/foo & +ssh -R /tmp/foo:/tmp/foo user@host restic -r rest:http+unix:///tmp/foo:/repo backup +``` + +https://github.com/restic/rest-server/pull/272 diff --git a/cmd/rest-server/listener_unix.go b/cmd/rest-server/listener_unix.go index f3baf00..8350d41 100644 --- a/cmd/rest-server/listener_unix.go +++ b/cmd/rest-server/listener_unix.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "strings" "github.com/coreos/go-systemd/v22/activation" ) @@ -23,9 +24,20 @@ func findListener(addr string) (listener net.Listener, err error) { switch len(listeners) { case 0: // no listeners found, listen manually - listener, err = net.Listen("tcp", addr) - if err != nil { - return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + if strings.HasPrefix(addr, "unix:") { // if we want to listen on a unix socket + unixAddr, err := net.ResolveUnixAddr("unix", strings.TrimPrefix(addr, "unix:")) + if err != nil { + return nil, fmt.Errorf("unable to understand unix address %s: %w", addr, err) + } + listener, err = net.ListenUnix("unix", unixAddr) + if err != nil { + return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + } + } else { // assume tcp + listener, err = net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + } } log.Printf("start server on %v", listener.Addr()) diff --git a/cmd/rest-server/listener_unix_test.go b/cmd/rest-server/listener_unix_test.go new file mode 100644 index 0000000..a4f32f4 --- /dev/null +++ b/cmd/rest-server/listener_unix_test.go @@ -0,0 +1,75 @@ +//go:build !windows +// +build !windows + +package main + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestUnixSocket(t *testing.T) { + td := t.TempDir() + + // this is the socket we'll listen on and connect to + tempSocket := filepath.Join(td, "sock") + + // create some content and parent dirs + if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil { + t.Fatal(err) + } + + // run the following twice, to test that the server will + // cleanup its socket file when quitting, which won't happen + // if it doesn't exit gracefully + for i := 0; i < 2; i++ { + err := testServerWithArgs([]string{ + "--no-auth", + "--path", filepath.Join(td, "data"), + "--listen", fmt.Sprintf("unix:%s", tempSocket), + }, time.Second, func(ctx context.Context, _ *restServerApp) error { + // custom client that will talk HTTP to unix socket + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", tempSocket) + }, + }, + } + for _, test := range []struct { + Path string + StatusCode int + }{ + {"/repo1/", http.StatusMethodNotAllowed}, + {"/repo1/config", http.StatusOK}, + {"/repo2/config", http.StatusNotFound}, + } { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://ignored"+test.Path, nil) + if err != nil { + return err + } + resp, err := client.Do(req) + if err != nil { + return err + } + resp.Body.Close() + if resp.StatusCode != test.StatusCode { + return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path) + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + } +}