diff --git a/src/go/networking/cmd/proxy/wsl_integration_linux.go b/src/go/networking/cmd/proxy/wsl_integration_linux.go index 97bb02c87d3..4676c1c877f 100644 --- a/src/go/networking/cmd/proxy/wsl_integration_linux.go +++ b/src/go/networking/cmd/proxy/wsl_integration_linux.go @@ -32,12 +32,15 @@ var ( logFile string socketFile string upstreamAddr string + udpBuffer int ) const ( defaultLogPath = "/var/log/wsl-proxy.log" defaultSocket = "/run/wsl-proxy.sock" bridgeIPAddr = "192.168.143.1" + // Set UDP buffer size to 8 MB + defaultUDPBufferSize = 8 * 1024 * 1024 // 8 MB in bytes ) func main() { @@ -45,6 +48,7 @@ func main() { flag.StringVar(&logFile, "logfile", defaultLogPath, "path to the logfile for wsl-proxy process") flag.StringVar(&socketFile, "socketFile", defaultSocket, "path to the .sock file for UNIX socket") flag.StringVar(&upstreamAddr, "upstreamAddress", bridgeIPAddr, "IP address of the upstream server to forward to") + flag.IntVar(&udpBuffer, "udpBuffer", defaultUDPBufferSize, "max buffer size in bytes for UDP socket I/O") flag.Parse() setupLogging(logFile) @@ -54,7 +58,11 @@ func main() { logrus.Fatalf("failed to create listener for published ports: %s", err) return } - proxy := portproxy.NewPortProxy(socket, bridgeIPAddr) + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: upstreamAddr, + UDPBufferSize: udpBuffer, + } + proxy := portproxy.NewPortProxy(socket, proxyConfig) // Handle graceful shutdown sigCh := make(chan os.Signal, 1) diff --git a/src/go/networking/pkg/portproxy/server.go b/src/go/networking/pkg/portproxy/server.go index 8d59e423a6c..7516246864e 100644 --- a/src/go/networking/pkg/portproxy/server.go +++ b/src/go/networking/pkg/portproxy/server.go @@ -21,34 +21,44 @@ import ( "net" "sync" + gvisorTypes "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types" "github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/utils" "github.com/sirupsen/logrus" ) +type ProxyConfig struct { + UpstreamAddress string + UDPBufferSize int +} + type PortProxy struct { - upstreamAddress string - listener net.Listener - quit chan struct{} - // map of port number as a key to associated listener + config *ProxyConfig + listener net.Listener + quit chan struct{} + // map of TCP port number as a key to associated listener activeListeners map[int]net.Listener - mutex sync.Mutex - wg sync.WaitGroup + listenerMutex sync.Mutex + // map of UDP port number as a key to associated UDPConn + activeUDPConns map[int]*net.UDPConn + udpConnMutex sync.Mutex + wg sync.WaitGroup } -func NewPortProxy(listener net.Listener, upstreamAddr string) *PortProxy { +func NewPortProxy(listener net.Listener, cfg *ProxyConfig) *PortProxy { portProxy := &PortProxy{ - upstreamAddress: upstreamAddr, + config: cfg, listener: listener, quit: make(chan struct{}), activeListeners: make(map[int]net.Listener), + activeUDPConns: make(map[int]*net.UDPConn), } return portProxy } func (p *PortProxy) Start() error { - logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.upstreamAddress) + logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.config.UpstreamAddress) for { conn, err := p.listener.Accept() if err != nil { @@ -65,6 +75,12 @@ func (p *PortProxy) Start() error { } } +func (p *PortProxy) UDPPortMappings() map[int]*net.UDPConn { + p.udpConnMutex.Lock() + defer p.udpConnMutex.Unlock() + return p.activeUDPConns +} + func (p *PortProxy) handleEvent(conn net.Conn) { defer conn.Close() @@ -73,47 +89,144 @@ func (p *PortProxy) handleEvent(conn net.Conn) { logrus.Errorf("port server decoding received payload error: %s", err) return } - p.execListener(pm) + p.exec(pm) } -func (p *PortProxy) execListener(pm types.PortMapping) { - for _, portBindings := range pm.Ports { - for _, portBinding := range portBindings { - logrus.Debugf("received the following port: [%s] from portMapping: %+v", portBinding.HostPort, pm) - port, err := nat.ParsePort(portBinding.HostPort) - if err != nil { - logrus.Errorf("parsing port error: %s", err) - continue - } - if pm.Remove { - p.mutex.Lock() - if listener, exist := p.activeListeners[port]; exist { - logrus.Debugf("closing listener for port: %d", port) - if err := listener.Close(); err != nil { - logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err) - } +func (p *PortProxy) exec(pm types.PortMapping) { + for portProto, portBindings := range pm.Ports { + logrus.Debugf("received the following port: [%s] and protocol: [%s] from portMapping: %+v", portProto.Port(), portProto.Proto(), pm) + + switch gvisorTypes.TransportProtocol(portProto.Proto()) { + case gvisorTypes.TCP: + p.handleTCP(portBindings, pm.Remove) + case gvisorTypes.UDP: + p.handleUDP(portBindings, pm.Remove) + default: + logrus.Warnf("unsupported protocol: [%s]", portProto.Proto()) + } + } +} + +func (p *PortProxy) handleUDP(portBindings []nat.PortBinding, remove bool) { + for _, portBinding := range portBindings { + port, err := nat.ParsePort(portBinding.HostPort) + if err != nil { + logrus.Errorf("parsing port error: %s", err) + continue + } + if remove { + p.udpConnMutex.Lock() + if udpConn, exist := p.activeUDPConns[port]; exist { + if err := udpConn.Close(); err != nil { + logrus.Errorf("error closing UDPConn for port [%s]: %s", portBinding.HostPort, err) } - delete(p.activeListeners, port) - p.mutex.Unlock() - continue } - addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) - l, err := net.Listen("tcp", addr) - if err != nil { - logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) - continue + delete(p.activeUDPConns, port) + p.udpConnMutex.Unlock() + logrus.Debugf("closing UDPConn for port: %d", port) + continue + } + + // the localAddress IP section can either be 0.0.0.0 or 127.0.0.1 + localAddress := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) + sourceAddr, err := net.ResolveUDPAddr("udp", localAddress) + if err != nil { + logrus.Errorf("failed to resolve UDP source address [%s]: %s", sourceAddr, err) + continue + } + + c, err := net.ListenUDP("udp", sourceAddr) + if err != nil { + logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) + continue + } + + forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, portBinding.HostPort) + targetAddr, err := net.ResolveUDPAddr("udp", forwardAddr) + if err != nil { + c.Close() + logrus.Errorf("failed to resolve UDP target address [%s]: %s", targetAddr, err) + continue + } + + p.udpConnMutex.Lock() + p.activeUDPConns[port] = c + p.udpConnMutex.Unlock() + logrus.Debugf("created UDPConn for: %v", sourceAddr) + + go p.acceptUDPConn(c, targetAddr) + } +} + +func (p *PortProxy) acceptUDPConn(sourceConn *net.UDPConn, targetAddr *net.UDPAddr) { + targetConn, err := net.DialUDP("udp", nil, targetAddr) + if err != nil { + logrus.Errorf("failed to connect to target address: %s : %s", targetAddr, err) + return + } + defer targetConn.Close() + p.wg.Add(1) + for { + b := make([]byte, p.config.UDPBufferSize) + n, addr, err := sourceConn.ReadFromUDP(b) + if err != nil && n == 0 { + logrus.Errorf("error reading UDP packet from source: %s : %s", addr, err) + if errors.Is(err, net.ErrClosed) { + p.wg.Done() + break + } + continue + } + logrus.Debugf("received %d data from %s", n, addr) + + n, err = targetConn.Write(b[:n]) + if err != nil { + logrus.Errorf("error forwarding UDP packet to target: %s : %s", targetAddr, err) + if errors.Is(err, net.ErrClosed) { + p.wg.Done() + break } - p.mutex.Lock() - p.activeListeners[port] = l - p.mutex.Unlock() - logrus.Debugf("created listener for: %s", addr) - go p.acceptTraffic(l, portBinding.HostPort) + continue } + logrus.Debugf("sent %d data to %s", n, targetAddr) + } +} + +func (p *PortProxy) handleTCP(portBindings []nat.PortBinding, remove bool) { + for _, portBinding := range portBindings { + port, err := nat.ParsePort(portBinding.HostPort) + if err != nil { + logrus.Errorf("parsing port error: %s", err) + continue + } + if remove { + p.listenerMutex.Lock() + if listener, exist := p.activeListeners[port]; exist { + logrus.Debugf("closing listener for port: %d", port) + if err := listener.Close(); err != nil { + logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err) + } + } + delete(p.activeListeners, port) + p.listenerMutex.Unlock() + continue + } + addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort) + l, err := net.Listen("tcp", addr) + if err != nil { + logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err) + continue + } + p.listenerMutex.Lock() + p.activeListeners[port] = l + p.listenerMutex.Unlock() + logrus.Debugf("created listener for: %s", addr) + go p.acceptTraffic(l, portBinding.HostPort) } } func (p *PortProxy) acceptTraffic(listener net.Listener, port string) { - forwardAddr := net.JoinHostPort(p.upstreamAddress, port) + forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, port) for { conn, err := listener.Accept() if err != nil { @@ -124,7 +237,7 @@ func (p *PortProxy) acceptTraffic(listener net.Listener, port string) { logrus.Errorf("port proxy listener failed to accept: %s", err) continue } - logrus.Debugf("port proxy accepted connection from %s", conn.RemoteAddr()) + logrus.Debugf("port proxy accepted TCP connection from %s", conn.RemoteAddr()) p.wg.Add(1) go func(conn net.Conn) { @@ -139,6 +252,9 @@ func (p *PortProxy) Close() error { // Close all the active listeners p.cleanupListeners() + // Close all active UDP connections + p.cleanupUDPConns() + // Close the listener first to prevent new connections. err := p.listener.Close() if err != nil { @@ -155,7 +271,17 @@ func (p *PortProxy) Close() error { } func (p *PortProxy) cleanupListeners() { + p.listenerMutex.Lock() + defer p.listenerMutex.Unlock() for _, l := range p.activeListeners { _ = l.Close() } } + +func (p *PortProxy) cleanupUDPConns() { + p.udpConnMutex.Lock() + defer p.udpConnMutex.Unlock() + for _, c := range p.activeUDPConns { + _ = c.Close() + } +} diff --git a/src/go/networking/pkg/portproxy/server_test.go b/src/go/networking/pkg/portproxy/server_test.go index d0b2304998b..473a96e8431 100644 --- a/src/go/networking/pkg/portproxy/server_test.go +++ b/src/go/networking/pkg/portproxy/server_test.go @@ -24,18 +24,91 @@ import ( "net/http" "syscall" "testing" + "time" "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types" "github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/portproxy" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "golang.org/x/net/nettest" ) -func TestNewPortProxy(t *testing.T) { - logrus.SetLevel(logrus.DebugLevel) +func TestNewPortProxyUDP(t *testing.T) { + testServerIP, err := availableIP() + require.NoError(t, err, "cannot continue with the test since there are no available IP addresses") + + remoteAddr := net.JoinHostPort(testServerIP, "0") + targetAddr, err := net.ResolveUDPAddr("udp", remoteAddr) + require.NoError(t, err) + targetConn, err := net.ListenUDP("udp", targetAddr) + require.NoError(t, err) + + t.Logf("created the following UDP target listener: %s", targetConn.LocalAddr().String()) + localListener, err := nettest.NewLocalListener("unix") + require.NoError(t, err) + defer localListener.Close() + + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: testServerIP, + UDPBufferSize: 1024, + } + portProxy := portproxy.NewPortProxy(localListener, proxyConfig) + go portProxy.Start() + + _, testPort, err := net.SplitHostPort(targetConn.LocalAddr().String()) + require.NoError(t, err) + + port, err := nat.NewPort("udp", testPort) + require.NoError(t, err) + + portMapping := types.PortMapping{ + Remove: false, + Ports: nat.PortMap{ + port: []nat.PortBinding{ + { + HostIP: "127.0.0.1", + HostPort: testPort, + }, + }, + }, + } + t.Logf("sending the following portMapping to portProxy: %+v", portMapping) + err = marshalAndSend(localListener, portMapping) + require.NoError(t, err) + + // indicate when UDP mappings are ready + for len(portProxy.UDPPortMappings()) == 0 { + time.Sleep(100 * time.Millisecond) + } + + t.Log("UDP port mappings are set up") + + localAddr := net.JoinHostPort("127.0.0.1", testPort) + sourceAddr, err := net.ResolveUDPAddr("udp", localAddr) + require.NoError(t, err) + sourceConn, err := net.DialUDP("udp", nil, sourceAddr) + require.NoError(t, err) + t.Logf("dialing in to the following UDP connection: %s", localAddr) + + expectedString := "this is what we expect" + _, err = sourceConn.Write([]byte(expectedString)) + require.NoError(t, err) + + targetConn.SetDeadline(time.Now().Add(time.Second * 5)) + + b := make([]byte, len(expectedString)) + n, _, err := targetConn.ReadFromUDP(b) + require.NoError(t, err) + require.Equal(t, n, len(expectedString)) + require.Equal(t, string(b), expectedString) + + targetConn.Close() + sourceConn.Close() + portProxy.Close() +} + +func TestNewPortProxyTCP(t *testing.T) { expectedResponse := "called the upstream server" testServerIP, err := availableIP() @@ -61,7 +134,10 @@ func TestNewPortProxy(t *testing.T) { require.NoError(t, err) defer localListener.Close() - portProxy := portproxy.NewPortProxy(localListener, testServerIP) + proxyConfig := &portproxy.ProxyConfig{ + UpstreamAddress: testServerIP, + } + portProxy := portproxy.NewPortProxy(localListener, proxyConfig) go portProxy.Start() getURL := fmt.Sprintf("http://localhost:%s", testPort) @@ -79,12 +155,13 @@ func TestNewPortProxy(t *testing.T) { Ports: nat.PortMap{ port: []nat.PortBinding{ { - HostIP: testServerIP, + HostIP: "127.0.0.1", HostPort: testPort, }, }, }, } + t.Logf("sending the following portMapping to portProxy: %+v", portMapping) err = marshalAndSend(localListener, portMapping) require.NoError(t, err) @@ -101,7 +178,7 @@ func TestNewPortProxy(t *testing.T) { Ports: nat.PortMap{ port: []nat.PortBinding{ { - HostIP: testServerIP, + HostIP: "127.0.0.1", HostPort: testPort, }, },