Skip to content

Commit

Permalink
feat: add Unix forwarding server implementations
Browse files Browse the repository at this point in the history
Adds optional (disabled by default) implementations of local->remote and
remote->local Unix forwarding through OpenSSH's protocol extensions:

- [email protected]
    - [email protected]
    - [email protected]
- [email protected]

Adds tests for Unix forwarding, reverse Unix forwarding and reverse TCP
forwarding.

Co-authored-by: Samuel Corsi-House <[email protected]>
  • Loading branch information
deansheather and samchouse committed Oct 23, 2024
1 parent 5a52e32 commit f9faa02
Show file tree
Hide file tree
Showing 9 changed files with 642 additions and 30 deletions.
2 changes: 1 addition & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
l := newLocalTCPListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
Expand Down
2 changes: 2 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Server struct {
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding ([email protected]), denies all if nil
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding ([email protected]), denies all if nil
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions

Expand Down
4 changes: 2 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) {
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
testBytes := []byte("Hello world\n")
s := &Server{
Handler: func(s Session) {
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) {
}

func TestServerClose(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
s := &Server{
Handler: func(s Session) {
time.Sleep(5 * time.Second)
Expand Down
19 changes: 15 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
return e
}
srv.ChannelHandlers = map[string]ChannelHandler{
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"[email protected]": DirectStreamLocalHandler,
}

forwardedTCPHandler := &ForwardedTCPHandler{}
forwardedUnixHandler := &ForwardedUnixHandler{}
srv.RequestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
}

srv.HandleConn(conn)
return nil
}

func newLocalListener() net.Listener {
func newLocalTCPListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
Expand Down Expand Up @@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
}

func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
l := newLocalListener()
l := newLocalTCPListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
}
Expand Down
20 changes: 20 additions & 0 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"crypto/subtle"
"errors"
"net"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -29,6 +30,9 @@ const (
// DefaultHandler is the default Handler used by Serve.
var DefaultHandler Handler

// ErrReject is returned by some callbacks to reject a request.
var ErrRejected = errors.New("ssh: rejected")

// Option is a functional option handler for Server.
type Option func(*Server) error

Expand Down Expand Up @@ -64,6 +68,22 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool

// LocalUnixForwardingCallback is a hook for allowing unix forwarding
// ([email protected]). Returning ErrRejected will reject the
// request. The returned net.Conn will be closed by the server when no longer
// needed.
//
// Use SimpleUnixLocalForwardingCallback for a basic implementation.
type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error)

// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
// ([email protected]). Returning ErrRejected will reject the
// request. The returned net.Listener will be closed by the server when no
// longer needed.
//
// Use SimpleUnixReverseForwardingCallback for a basic implementation.
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)

// ServerConfigCallback is a hook for creating custom default server configs
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig

Expand Down
252 changes: 252 additions & 0 deletions streamlocal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
package ssh

import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"sync"
"syscall"

gossh "golang.org/x/crypto/ssh"
)

const (
forwardedUnixChannelType = "[email protected]"
)

// directStreamLocalChannelData data struct as specified in OpenSSH's protocol
// extensions document, Section 2.4.
// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD
type directStreamLocalChannelData struct {
SocketPath string

Reserved1 string
Reserved2 uint32
}

// DirectStreamLocalHandler provides Unix forwarding from client -> server. It
// can be enabled by adding it to the server's ChannelHandlers under
// `[email protected]`.
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
var d directStreamLocalChannelData
err := gossh.Unmarshal(newChan.ExtraData(), &d)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error())
return
}

if srv.LocalUnixForwardingCallback == nil {
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
return
}
dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath)
if err != nil {
if errors.Is(err, ErrRejected) {
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
return
}
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error()))
return
}

ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go gossh.DiscardRequests(reqs)

bicopy(ctx, ch, dconn)
}

// remoteUnixForwardRequest describes the extra data sent in a
// [email protected] containing the socket path to bind to.
type remoteUnixForwardRequest struct {
SocketPath string
}

// remoteUnixForwardChannelData describes the data sent as the payload in the new
// channel request when a Unix connection is accepted by the listener.
type remoteUnixForwardChannelData struct {
SocketPath string
Reserved uint32
}

// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// `[email protected]` and
// `[email protected]`
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
type ForwardedUnixHandler struct {
sync.Mutex
forwards map[string]net.Listener
}

func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
}
h.Unlock()
conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
if !ok {
// TODO: log cast failure
return false, nil
}

switch req.Type {
case "[email protected]":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}

if srv.ReverseUnixForwardingCallback == nil {
return false, []byte("unix forwarding is disabled")
}

addr := reqPayload.SocketPath
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
// TODO: log failure
return false, nil
}

ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
if err != nil {
if errors.Is(err, ErrRejected) {
return false, []byte("unix forwarding is disabled")
}
// TODO: log unix listen failure
return false, nil
}

// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
// listener as you can't listen on the same socket twice.
//
// This is also what the TCP version of this code does.
h.Lock()
h.forwards[addr] = ln
h.Unlock()

ctx, cancel := context.WithCancel(ctx)
go func() {
<-ctx.Done()
_ = ln.Close()
}()
go func() {
defer cancel()

for {
c, err := ln.Accept()
if err != nil {
// closed below
break
}
payload := gossh.Marshal(&remoteUnixForwardChannelData{
SocketPath: addr,
})

go func() {
ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload)
if err != nil {
_ = c.Close()
return
}
go gossh.DiscardRequests(reqs)
bicopy(ctx, ch, c)
}()
}

h.Lock()
ln2, ok := h.forwards[addr]
if ok && ln2 == ln {
delete(h.forwards, addr)
}
h.Unlock()
_ = ln.Close()
}()

return true, nil

case "[email protected]":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()
if ok {
_ = ln.Close()
}
return true, nil

default:
return false, nil
}
}

// unlink removes files and unlike os.Remove, directories are kept.
func unlink(path string) error {
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
// for more details.
for {
err := syscall.Unlink(path)
if !errors.Is(err, syscall.EINTR) {
return err
}
}
}

// SimpleUnixLocalForwardingCallback provides a basic implementation for
// LocalUnixForwardingCallback. It will simply dial the requested socket using
// a context-aware dialer.
func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
}

// SimpleUnixReverseForwardingCallback provides a basic implementation for
// ReverseUnixForwardingCallback. The parent directory will be created (with
// os.MkdirAll), and existing files with the same name will be removed.
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
// Create socket parent dir if not exists.
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0700)
if err != nil {
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
}

// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(socketPath)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
}

ln, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
}

return ln, err
}
Loading

0 comments on commit f9faa02

Please sign in to comment.