Skip to content

Commit

Permalink
Support UDP in wsl-proxy
Browse files Browse the repository at this point in the history
The wsl-proxy previously only supported TCP connections.
With this change, it now enables the listening and handling
of UDP packets in the proxy.

Signed-off-by: Nino Kodabande <[email protected]>
  • Loading branch information
Nino-K committed Oct 10, 2024
1 parent 9c4941f commit a9fca11
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 42 deletions.
10 changes: 9 additions & 1 deletion src/go/networking/cmd/proxy/wsl_integration_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,23 @@ var (
logFile string
socketFile string
upstreamAddr string
udpBuffer int
)

const (
defaultLogPath = "/var/log/wsl-proxy.log"
defaultSocket = "/run/wsl-proxy.sock"
bridgeIPAddr = "192.168.143.1"
// Set UDP buffer size to 8 MB
defaultUDPBufferSize = 8388608 // 8 MB in bytes
)

func main() {
flag.BoolVar(&debug, "debug", false, "enable additional debugging.")
flag.StringVar(&logFile, "logfile", defaultLogPath, "path to the logfile for wsl-proxy process")
flag.StringVar(&socketFile, "socketFile", defaultSocket, "path to the .sock file for UNIX socket")
flag.StringVar(&upstreamAddr, "upstreamAddress", bridgeIPAddr, "IP address of the upstream server to forward to")
flag.IntVar(&udpBuffer, "udpBuffer", defaultUDPBufferSize, "max buffer size for the UDP socket io")
flag.Parse()

setupLogging(logFile)
Expand All @@ -54,7 +58,11 @@ func main() {
logrus.Fatalf("failed to create listener for published ports: %s", err)
return
}
proxy := portproxy.NewPortProxy(socket, bridgeIPAddr)
proxyConfig := &portproxy.ProxyConfig{
UpstreamAddress: upstreamAddr,
UDPBufferSize: udpBuffer,
}
proxy := portproxy.NewPortProxy(socket, proxyConfig)

// Handle graceful shutdown
sigCh := make(chan os.Signal, 1)
Expand Down
194 changes: 153 additions & 41 deletions src/go/networking/pkg/portproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,44 @@ import (
"net"
"sync"

gvisorTypes "github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/docker/go-connections/nat"
"github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types"
"github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/utils"
"github.com/sirupsen/logrus"
)

type ProxyConfig struct {
UpstreamAddress string
UDPBufferSize int
}

type PortProxy struct {
upstreamAddress string
listener net.Listener
quit chan struct{}
// map of port number as a key to associated listener
config *ProxyConfig
listener net.Listener
quit chan struct{}
// map of TCP port number as a key to associated listener
activeListeners map[int]net.Listener
mutex sync.Mutex
wg sync.WaitGroup
listenerMutex sync.Mutex
// map of UDP port number as a key to associated UDPConn
activeUDPConns map[int]*net.UDPConn
udpConnMutex sync.Mutex
wg sync.WaitGroup
}

func NewPortProxy(listener net.Listener, upstreamAddr string) *PortProxy {
func NewPortProxy(listener net.Listener, cfg *ProxyConfig) *PortProxy {
portProxy := &PortProxy{
upstreamAddress: upstreamAddr,
config: cfg,
listener: listener,
quit: make(chan struct{}),
activeListeners: make(map[int]net.Listener),
activeUDPConns: make(map[int]*net.UDPConn),
}
return portProxy
}

func (p *PortProxy) Start() error {
logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.upstreamAddress)
logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.config.UpstreamAddress)
for {
conn, err := p.listener.Accept()
if err != nil {
Expand All @@ -73,47 +83,140 @@ func (p *PortProxy) handleEvent(conn net.Conn) {
logrus.Errorf("port server decoding received payload error: %s", err)
return
}
p.execListener(pm)
p.exec(pm)
}

func (p *PortProxy) execListener(pm types.PortMapping) {
for _, portBindings := range pm.Ports {
for _, portBinding := range portBindings {
logrus.Debugf("received the following port: [%s] from portMapping: %+v", portBinding.HostPort, pm)
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if pm.Remove {
p.mutex.Lock()
if listener, exist := p.activeListeners[port]; exist {
logrus.Debugf("closing listener for port: %d", port)
if err := listener.Close(); err != nil {
logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err)
}
func (p *PortProxy) exec(pm types.PortMapping) {
for portProto, portBindings := range pm.Ports {
logrus.Debugf("received the following port: [%s] and protocol: [%s] from portMapping: %+v", portProto.Port(), portProto.Proto(), pm)
if gvisorTypes.TransportProtocol(portProto.Proto()) == gvisorTypes.TCP {
p.handleTCP(portBindings, pm.Remove)
}
if gvisorTypes.TransportProtocol(portProto.Proto()) == gvisorTypes.UDP {
p.handleUDP(portBindings, pm.Remove)
}
}
}

func (p *PortProxy) handleUDP(portBindings []nat.PortBinding, remove bool) {
for _, portBinding := range portBindings {
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if remove {
p.udpConnMutex.Lock()
if udpConn, exist := p.activeUDPConns[port]; exist {
logrus.Debugf("closing UDPConn for port: %d", port)
if err := udpConn.Close(); err != nil {
logrus.Errorf("error closing UDPConn for port [%s]: %s", portBinding.HostPort, err)
}
delete(p.activeListeners, port)
p.mutex.Unlock()
continue
}
addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
l, err := net.Listen("tcp", addr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
delete(p.activeUDPConns, port)
p.udpConnMutex.Unlock()
continue
}

// the localAddress IP section can either be 0.0.0.0 or 127.0.0.1
localAddress := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
sourceAddr, err := net.ResolveUDPAddr("udp", localAddress)
if err != nil {
logrus.Errorf("failed to resolve UDP source address [%s]: %s", sourceAddr, err)
continue
}

c, err := net.ListenUDP("udp", sourceAddr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
}

p.udpConnMutex.Lock()
defer p.udpConnMutex.Unlock()
p.activeUDPConns[port] = c
logrus.Debugf("created UDPConn for: %v", sourceAddr)

forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, portBinding.HostPort)
targetAddr, err := net.ResolveUDPAddr("udp", forwardAddr)
if err != nil {
logrus.Errorf("failed to resolve UDP target address [%s]: %s", targetAddr, err)
continue
}
go p.acceptUDPConn(c, targetAddr)
}

}

Check failure on line 149 in src/go/networking/pkg/portproxy/server.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

unnecessary trailing newline (whitespace)

func (p *PortProxy) acceptUDPConn(sourceConn *net.UDPConn, targetAddr *net.UDPAddr) {
targetConn, err := net.DialUDP("udp", nil, targetAddr)
if err != nil {
logrus.Errorf("failed to connect to target address: %s : %s", targetAddr, err)
return
}
defer targetConn.Close()
p.wg.Add(1)
for {
b := make([]byte, p.config.UDPBufferSize)
n, addr, err := sourceConn.ReadFromUDP(b)
if err != nil {
logrus.Errorf("error reading UDP packet from source: %s : %s", addr, err)
if errors.Is(err, net.ErrClosed) {
p.wg.Done()
break
}
continue
}
logrus.Debugf("received %d data from %s", n, addr)

n, err = targetConn.Write(b[:n])
if err != nil {
logrus.Errorf("error forwarding UDP packet to target: %s : %s", targetAddr, err)
if errors.Is(err, net.ErrClosed) {
p.wg.Done()
break
}
p.mutex.Lock()
p.activeListeners[port] = l
p.mutex.Unlock()
logrus.Debugf("created listener for: %s", addr)
go p.acceptTraffic(l, portBinding.HostPort)
continue
}
logrus.Debugf("sent %d data to %s", n, targetAddr)
}
}

func (p *PortProxy) handleTCP(portBindings []nat.PortBinding, remove bool) {
for _, portBinding := range portBindings {
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if remove {
p.listenerMutex.Lock()
if listener, exist := p.activeListeners[port]; exist {
logrus.Debugf("closing listener for port: %d", port)
if err := listener.Close(); err != nil {
logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err)
}
}
delete(p.activeListeners, port)
p.listenerMutex.Unlock()
continue
}
addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
l, err := net.Listen("tcp", addr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
}
p.listenerMutex.Lock()
p.activeListeners[port] = l
p.listenerMutex.Unlock()
logrus.Debugf("created listener for: %s", addr)
go p.acceptTraffic(l, portBinding.HostPort)
}
}

func (p *PortProxy) acceptTraffic(listener net.Listener, port string) {
forwardAddr := net.JoinHostPort(p.upstreamAddress, port)
forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, port)
for {
conn, err := listener.Accept()
if err != nil {
Expand All @@ -124,7 +227,7 @@ func (p *PortProxy) acceptTraffic(listener net.Listener, port string) {
logrus.Errorf("port proxy listener failed to accept: %s", err)
continue
}
logrus.Debugf("port proxy accepted connection from %s", conn.RemoteAddr())
logrus.Debugf("port proxy accepted TCP connection from %s", conn.RemoteAddr())
p.wg.Add(1)

go func(conn net.Conn) {
Expand All @@ -139,6 +242,9 @@ func (p *PortProxy) Close() error {
// Close all the active listeners
p.cleanupListeners()

// Close all active UDP connections
p.cleanupUDPConns()

// Close the listener first to prevent new connections.
err := p.listener.Close()
if err != nil {
Expand All @@ -159,3 +265,9 @@ func (p *PortProxy) cleanupListeners() {
_ = l.Close()
}
}

func (p *PortProxy) cleanupUDPConns() {
for _, c := range p.activeUDPConns {
_ = c.Close()
}
}

0 comments on commit a9fca11

Please sign in to comment.