From aa2d683e627395b1552071c831a4a96aabd1ad11 Mon Sep 17 00:00:00 2001 From: Nino Kodabande Date: Thu, 10 Oct 2024 10:33:15 -0700 Subject: [PATCH 1/4] Support UDP in wsl-proxy The wsl-proxy previously only supported TCP connections. With this change, it now enables the listening and handling of UDP packets in the proxy. Signed-off-by: Nino Kodabande --- .../cmd/proxy/wsl_integration_linux.go | 10 +- src/go/networking/pkg/portproxy/server.go | 198 ++++++++++++++---- .../networking/pkg/portproxy/server_test.go | 5 +- 3 files changed, 170 insertions(+), 43 deletions(-) diff --git a/src/go/networking/cmd/proxy/wsl_integration_linux.go b/src/go/networking/cmd/proxy/wsl_integration_linux.go index 97bb02c87d3..4676c1c877f 100644 --- a/src/go/networking/cmd/proxy/wsl_integration_linux.go +++ b/src/go/networking/cmd/proxy/wsl_integration_linux.go @@ -32,12 +32,15 @@ var ( logFile string socketFile string upstreamAddr string + udpBuffer int ) const ( defaultLogPath = "/var/log/wsl-proxy.log" defaultSocket = "/run/wsl-proxy.sock" bridgeIPAddr = "192.168.143.1" + // Set UDP buffer size to 8 MB + defaultUDPBufferSize = 8 * 1024 * 1024 // 8 MB in bytes ) func main() { @@ -45,6 +48,7 @@ func main() { flag.StringVar(&logFile, "logfile", defaultLogPath, "path to the logfile for wsl-proxy process") flag.StringVar(&socketFile, "socketFile", defaultSocket, "path to the .sock file for UNIX socket") flag.StringVar(&upstreamAddr, "upstreamAddress", bridgeIPAddr, "IP address of the upstream server to forward to") + flag.IntVar(&udpBuffer, "udpBuffer", defaultUDPBufferSize, "max buffer size in bytes for UDP socket I/O") flag.Parse() setupLogging(logFile) @@ -54,7 +58,11 @@ func main() { logrus.Fatalf("failed to create listener for published ports: %s", err) return } - proxy := portproxy.NewPortProxy(socket, bridgeIPAddr) + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: upstreamAddr, + UDPBufferSize: udpBuffer, + } + proxy := portproxy.NewPortProxy(socket, proxyConfig) // Handle graceful shutdown sigCh := make(chan os.Signal, 1) diff --git a/src/go/networking/pkg/portproxy/server.go b/src/go/networking/pkg/portproxy/server.go index 8d59e423a6c..1201ee0ebf4 100644 --- a/src/go/networking/pkg/portproxy/server.go +++ b/src/go/networking/pkg/portproxy/server.go @@ -21,34 +21,44 @@ import ( "net" "sync" + gvisorTypes "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types" "github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/utils" "github.com/sirupsen/logrus" ) +type ProxyConfig struct { + UpstreamAddress string + UDPBufferSize int +} + type PortProxy struct { - upstreamAddress string - listener net.Listener - quit chan struct{} - // map of port number as a key to associated listener + config *ProxyConfig + listener net.Listener + quit chan struct{} + // map of TCP port number as a key to associated listener activeListeners map[int]net.Listener - mutex sync.Mutex - wg sync.WaitGroup + listenerMutex sync.Mutex + // map of UDP port number as a key to associated UDPConn + activeUDPConns map[int]*net.UDPConn + udpConnMutex sync.Mutex + wg sync.WaitGroup } -func NewPortProxy(listener net.Listener, upstreamAddr string) *PortProxy { +func NewPortProxy(listener net.Listener, cfg *ProxyConfig) *PortProxy { portProxy := &PortProxy{ - upstreamAddress: upstreamAddr, + config: cfg, listener: listener, quit: make(chan struct{}), activeListeners: make(map[int]net.Listener), + activeUDPConns: make(map[int]*net.UDPConn), } return portProxy } func (p *PortProxy) Start() error { - logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.upstreamAddress) + logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.config.UpstreamAddress) for { conn, err := p.listener.Accept() if err != nil { @@ -73,47 +83,144 @@ func (p *PortProxy) handleEvent(conn net.Conn) { logrus.Errorf("port server decoding received payload error: %s", err) return } - p.execListener(pm) + p.exec(pm) } -func (p *PortProxy) execListener(pm types.PortMapping) { - for _, portBindings := range pm.Ports { - for _, portBinding := range portBindings { - logrus.Debugf("received the following port: [%s] from portMapping: %+v", portBinding.HostPort, pm) - port, err := nat.ParsePort(portBinding.HostPort) - if err != nil { - logrus.Errorf("parsing port error: %s", err) - continue - } - if pm.Remove { - p.mutex.Lock() - if listener, exist := p.activeListeners[port]; exist { - logrus.Debugf("closing listener for port: %d", port) - if err := listener.Close(); err != nil { - logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err) - } +func (p *PortProxy) exec(pm types.PortMapping) { + for portProto, portBindings := range pm.Ports { + logrus.Debugf("received the following port: [%s] and protocol: [%s] from portMapping: %+v", portProto.Port(), portProto.Proto(), pm) + + switch gvisorTypes.TransportProtocol(portProto.Proto()) { + case gvisorTypes.TCP: + p.handleTCP(portBindings, pm.Remove) + case gvisorTypes.UDP: + p.handleUDP(portBindings, pm.Remove) + default: + logrus.Warnf("unsupported protocol: [%s]", portProto.Proto()) + } + } +} + +func (p *PortProxy) handleUDP(portBindings []nat.PortBinding, remove bool) { + for _, portBinding := range portBindings { + port, err := nat.ParsePort(portBinding.HostPort) + if err != nil { + logrus.Errorf("parsing port error: %s", err) + continue + } + if remove { + p.udpConnMutex.Lock() + if udpConn, exist := p.activeUDPConns[port]; exist { + if err := udpConn.Close(); err != nil { + logrus.Errorf("error closing UDPConn for port [%s]: %s", portBinding.HostPort, err) } - delete(p.activeListeners, port) - p.mutex.Unlock() - continue } - addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) - l, err := net.Listen("tcp", addr) - if err != nil { - logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) - continue + delete(p.activeUDPConns, port) + p.udpConnMutex.Unlock() + logrus.Debugf("closing UDPConn for port: %d", port) + continue + } + + // the localAddress IP section can either be 0.0.0.0 or 127.0.0.1 + localAddress := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) + sourceAddr, err := net.ResolveUDPAddr("udp", localAddress) + if err != nil { + logrus.Errorf("failed to resolve UDP source address [%s]: %s", sourceAddr, err) + continue + } + + c, err := net.ListenUDP("udp", sourceAddr) + if err != nil { + logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) + continue + } + + forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, portBinding.HostPort) + targetAddr, err := net.ResolveUDPAddr("udp", forwardAddr) + if err != nil { + c.Close() + logrus.Errorf("failed to resolve UDP target address [%s]: %s", targetAddr, err) + continue + } + + p.udpConnMutex.Lock() + p.activeUDPConns[port] = c + p.udpConnMutex.Unlock() + logrus.Debugf("created UDPConn for: %v", sourceAddr) + + go p.acceptUDPConn(c, targetAddr) + } +} + +func (p *PortProxy) acceptUDPConn(sourceConn *net.UDPConn, targetAddr *net.UDPAddr) { + targetConn, err := net.DialUDP("udp", nil, targetAddr) + if err != nil { + logrus.Errorf("failed to connect to target address: %s : %s", targetAddr, err) + return + } + defer targetConn.Close() + p.wg.Add(1) + for { + b := make([]byte, p.config.UDPBufferSize) + n, addr, err := sourceConn.ReadFromUDP(b) + if err != nil && n == 0 { + logrus.Errorf("error reading UDP packet from source: %s : %s", addr, err) + if errors.Is(err, net.ErrClosed) { + p.wg.Done() + break + } + continue + } + logrus.Debugf("received %d data from %s", n, addr) + + n, err = targetConn.Write(b[:n]) + if err != nil { + logrus.Errorf("error forwarding UDP packet to target: %s : %s", targetAddr, err) + if errors.Is(err, net.ErrClosed) { + p.wg.Done() + break + } + continue + } + logrus.Debugf("sent %d data to %s", n, targetAddr) + } +} + +func (p *PortProxy) handleTCP(portBindings []nat.PortBinding, remove bool) { + for _, portBinding := range portBindings { + port, err := nat.ParsePort(portBinding.HostPort) + if err != nil { + logrus.Errorf("parsing port error: %s", err) + continue + } + if remove { + p.listenerMutex.Lock() + if listener, exist := p.activeListeners[port]; exist { + logrus.Debugf("closing listener for port: %d", port) + if err := listener.Close(); err != nil { + logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err) + } } - p.mutex.Lock() - p.activeListeners[port] = l - p.mutex.Unlock() - logrus.Debugf("created listener for: %s", addr) - go p.acceptTraffic(l, portBinding.HostPort) + delete(p.activeListeners, port) + p.listenerMutex.Unlock() + continue } + addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) + l, err := net.Listen("tcp", addr) + if err != nil { + logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) + continue + } + p.listenerMutex.Lock() + p.activeListeners[port] = l + p.listenerMutex.Unlock() + logrus.Debugf("created listener for: %s", addr) + go p.acceptTraffic(l, portBinding.HostPort) } } func (p *PortProxy) acceptTraffic(listener net.Listener, port string) { - forwardAddr := net.JoinHostPort(p.upstreamAddress, port) + forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, port) for { conn, err := listener.Accept() if err != nil { @@ -124,7 +231,7 @@ func (p *PortProxy) acceptTraffic(listener net.Listener, port string) { logrus.Errorf("port proxy listener failed to accept: %s", err) continue } - logrus.Debugf("port proxy accepted connection from %s", conn.RemoteAddr()) + logrus.Debugf("port proxy accepted TCP connection from %s", conn.RemoteAddr()) p.wg.Add(1) go func(conn net.Conn) { @@ -139,6 +246,9 @@ func (p *PortProxy) Close() error { // Close all the active listeners p.cleanupListeners() + // Close all active UDP connections + p.cleanupUDPConns() + // Close the listener first to prevent new connections. err := p.listener.Close() if err != nil { @@ -159,3 +269,9 @@ func (p *PortProxy) cleanupListeners() { _ = l.Close() } } + +func (p *PortProxy) cleanupUDPConns() { + for _, c := range p.activeUDPConns { + _ = c.Close() + } +} diff --git a/src/go/networking/pkg/portproxy/server_test.go b/src/go/networking/pkg/portproxy/server_test.go index d0b2304998b..10850c892d4 100644 --- a/src/go/networking/pkg/portproxy/server_test.go +++ b/src/go/networking/pkg/portproxy/server_test.go @@ -61,7 +61,10 @@ func TestNewPortProxy(t *testing.T) { require.NoError(t, err) defer localListener.Close() - portProxy := portproxy.NewPortProxy(localListener, testServerIP) + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: testServerIP, + } + portProxy := portproxy.NewPortProxy(localListener, proxyConfig) go portProxy.Start() getURL := fmt.Sprintf("http://localhost:%s", testPort) From 2beb262c8a1b9dc335fda2ff6918e285f7ca4607 Mon Sep 17 00:00:00 2001 From: Nino Kodabande Date: Fri, 11 Oct 2024 00:08:10 -0700 Subject: [PATCH 2/4] Fixes the PortProxy TCP test Signed-off-by: Nino Kodabande --- src/go/networking/pkg/portproxy/server_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/go/networking/pkg/portproxy/server_test.go b/src/go/networking/pkg/portproxy/server_test.go index 10850c892d4..74cfc6b8a04 100644 --- a/src/go/networking/pkg/portproxy/server_test.go +++ b/src/go/networking/pkg/portproxy/server_test.go @@ -82,7 +82,7 @@ func TestNewPortProxy(t *testing.T) { Ports: nat.PortMap{ port: []nat.PortBinding{ { - HostIP: testServerIP, + HostIP: "127.0.0.1", HostPort: testPort, }, }, @@ -104,7 +104,7 @@ func TestNewPortProxy(t *testing.T) { Ports: nat.PortMap{ port: []nat.PortBinding{ { - HostIP: testServerIP, + HostIP: "127.0.0.1", HostPort: testPort, }, }, From c34149e814c29cc04727b74ab5c64f69f674c487 Mon Sep 17 00:00:00 2001 From: Nino Kodabande Date: Fri, 11 Oct 2024 10:59:58 -0700 Subject: [PATCH 3/4] Add a test for UDP connectivity Signed-off-by: Nino Kodabande --- src/go/networking/pkg/portproxy/server.go | 6 ++ .../networking/pkg/portproxy/server_test.go | 80 ++++++++++++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/go/networking/pkg/portproxy/server.go b/src/go/networking/pkg/portproxy/server.go index 1201ee0ebf4..da82798b25a 100644 --- a/src/go/networking/pkg/portproxy/server.go +++ b/src/go/networking/pkg/portproxy/server.go @@ -75,6 +75,12 @@ func (p *PortProxy) Start() error { } } +func (p *PortProxy) UDPPortMappings() map[int]*net.UDPConn { + p.udpConnMutex.Lock() + defer p.udpConnMutex.Unlock() + return p.activeUDPConns +} + func (p *PortProxy) handleEvent(conn net.Conn) { defer conn.Close() diff --git a/src/go/networking/pkg/portproxy/server_test.go b/src/go/networking/pkg/portproxy/server_test.go index 74cfc6b8a04..473a96e8431 100644 --- a/src/go/networking/pkg/portproxy/server_test.go +++ b/src/go/networking/pkg/portproxy/server_test.go @@ -24,18 +24,91 @@ import ( "net/http" "syscall" "testing" + "time" "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types" "github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/portproxy" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "golang.org/x/net/nettest" ) -func TestNewPortProxy(t *testing.T) { - logrus.SetLevel(logrus.DebugLevel) +func TestNewPortProxyUDP(t *testing.T) { + testServerIP, err := availableIP() + require.NoError(t, err, "cannot continue with the test since there are no available IP addresses") + + remoteAddr := net.JoinHostPort(testServerIP, "0") + targetAddr, err := net.ResolveUDPAddr("udp", remoteAddr) + require.NoError(t, err) + targetConn, err := net.ListenUDP("udp", targetAddr) + require.NoError(t, err) + + t.Logf("created the following UDP target listener: %s", targetConn.LocalAddr().String()) + + localListener, err := nettest.NewLocalListener("unix") + require.NoError(t, err) + defer localListener.Close() + + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: testServerIP, + UDPBufferSize: 1024, + } + portProxy := portproxy.NewPortProxy(localListener, proxyConfig) + go portProxy.Start() + + _, testPort, err := net.SplitHostPort(targetConn.LocalAddr().String()) + require.NoError(t, err) + + port, err := nat.NewPort("udp", testPort) + require.NoError(t, err) + + portMapping := types.PortMapping{ + Remove: false, + Ports: nat.PortMap{ + port: []nat.PortBinding{ + { + HostIP: "127.0.0.1", + HostPort: testPort, + }, + }, + }, + } + t.Logf("sending the following portMapping to portProxy: %+v", portMapping) + err = marshalAndSend(localListener, portMapping) + require.NoError(t, err) + + // indicate when UDP mappings are ready + for len(portProxy.UDPPortMappings()) == 0 { + time.Sleep(100 * time.Millisecond) + } + + t.Log("UDP port mappings are set up") + + localAddr := net.JoinHostPort("127.0.0.1", testPort) + sourceAddr, err := net.ResolveUDPAddr("udp", localAddr) + require.NoError(t, err) + sourceConn, err := net.DialUDP("udp", nil, sourceAddr) + require.NoError(t, err) + t.Logf("dialing in to the following UDP connection: %s", localAddr) + + expectedString := "this is what we expect" + _, err = sourceConn.Write([]byte(expectedString)) + require.NoError(t, err) + + targetConn.SetDeadline(time.Now().Add(time.Second * 5)) + + b := make([]byte, len(expectedString)) + n, _, err := targetConn.ReadFromUDP(b) + require.NoError(t, err) + require.Equal(t, n, len(expectedString)) + require.Equal(t, string(b), expectedString) + + targetConn.Close() + sourceConn.Close() + portProxy.Close() +} +func TestNewPortProxyTCP(t *testing.T) { expectedResponse := "called the upstream server" testServerIP, err := availableIP() @@ -88,6 +161,7 @@ func TestNewPortProxy(t *testing.T) { }, }, } + t.Logf("sending the following portMapping to portProxy: %+v", portMapping) err = marshalAndSend(localListener, portMapping) require.NoError(t, err) From 6d9ab24180daaf23d515b58cdd81f6568e7bb949 Mon Sep 17 00:00:00 2001 From: Nino Kodabande Date: Fri, 11 Oct 2024 11:39:34 -0700 Subject: [PATCH 4/4] Add the corresponding mutexes when cleaning up connections Signed-off-by: Nino Kodabande --- src/go/networking/pkg/portproxy/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/go/networking/pkg/portproxy/server.go b/src/go/networking/pkg/portproxy/server.go index da82798b25a..7516246864e 100644 --- a/src/go/networking/pkg/portproxy/server.go +++ b/src/go/networking/pkg/portproxy/server.go @@ -271,12 +271,16 @@ func (p *PortProxy) Close() error { } func (p *PortProxy) cleanupListeners() { + p.listenerMutex.Lock() + defer p.listenerMutex.Unlock() for _, l := range p.activeListeners { _ = l.Close() } } func (p *PortProxy) cleanupUDPConns() { + p.udpConnMutex.Lock() + defer p.udpConnMutex.Unlock() for _, c := range p.activeUDPConns { _ = c.Close() }