Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support UDP in wsl-proxy #7618

Merged
merged 4 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/go/networking/cmd/proxy/wsl_integration_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,23 @@ var (
logFile string
socketFile string
upstreamAddr string
udpBuffer int
)

const (
defaultLogPath = "/var/log/wsl-proxy.log"
defaultSocket = "/run/wsl-proxy.sock"
bridgeIPAddr = "192.168.143.1"
// Set UDP buffer size to 8 MB
defaultUDPBufferSize = 8 * 1024 * 1024 // 8 MB in bytes
)

func main() {
flag.BoolVar(&debug, "debug", false, "enable additional debugging.")
flag.StringVar(&logFile, "logfile", defaultLogPath, "path to the logfile for wsl-proxy process")
flag.StringVar(&socketFile, "socketFile", defaultSocket, "path to the .sock file for UNIX socket")
flag.StringVar(&upstreamAddr, "upstreamAddress", bridgeIPAddr, "IP address of the upstream server to forward to")
flag.IntVar(&udpBuffer, "udpBuffer", defaultUDPBufferSize, "max buffer size in bytes for UDP socket I/O")
flag.Parse()

setupLogging(logFile)
Expand All @@ -54,7 +58,11 @@ func main() {
logrus.Fatalf("failed to create listener for published ports: %s", err)
return
}
proxy := portproxy.NewPortProxy(socket, bridgeIPAddr)
proxyConfig := &portproxy.ProxyConfig{
UpstreamAddress: upstreamAddr,
UDPBufferSize: udpBuffer,
}
proxy := portproxy.NewPortProxy(socket, proxyConfig)

// Handle graceful shutdown
sigCh := make(chan os.Signal, 1)
Expand Down
208 changes: 167 additions & 41 deletions src/go/networking/pkg/portproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,44 @@ import (
"net"
"sync"

gvisorTypes "github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/docker/go-connections/nat"
"github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/types"
"github.com/rancher-sandbox/rancher-desktop/src/go/networking/pkg/utils"
"github.com/sirupsen/logrus"
)

type ProxyConfig struct {
UpstreamAddress string
UDPBufferSize int
}

type PortProxy struct {
upstreamAddress string
listener net.Listener
quit chan struct{}
// map of port number as a key to associated listener
config *ProxyConfig
listener net.Listener
quit chan struct{}
// map of TCP port number as a key to associated listener
activeListeners map[int]net.Listener
mutex sync.Mutex
wg sync.WaitGroup
listenerMutex sync.Mutex
// map of UDP port number as a key to associated UDPConn
activeUDPConns map[int]*net.UDPConn
udpConnMutex sync.Mutex
wg sync.WaitGroup
}

func NewPortProxy(listener net.Listener, upstreamAddr string) *PortProxy {
func NewPortProxy(listener net.Listener, cfg *ProxyConfig) *PortProxy {
portProxy := &PortProxy{
upstreamAddress: upstreamAddr,
config: cfg,
listener: listener,
quit: make(chan struct{}),
activeListeners: make(map[int]net.Listener),
activeUDPConns: make(map[int]*net.UDPConn),
}
return portProxy
}

func (p *PortProxy) Start() error {
logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.upstreamAddress)
logrus.Infof("Proxy server started accepting on %s, forwarding to %s", p.listener.Addr(), p.config.UpstreamAddress)
for {
conn, err := p.listener.Accept()
if err != nil {
Expand All @@ -65,6 +75,12 @@ func (p *PortProxy) Start() error {
}
}

func (p *PortProxy) UDPPortMappings() map[int]*net.UDPConn {
p.udpConnMutex.Lock()
defer p.udpConnMutex.Unlock()
return p.activeUDPConns
}

func (p *PortProxy) handleEvent(conn net.Conn) {
defer conn.Close()

Expand All @@ -73,47 +89,144 @@ func (p *PortProxy) handleEvent(conn net.Conn) {
logrus.Errorf("port server decoding received payload error: %s", err)
return
}
p.execListener(pm)
p.exec(pm)
}

func (p *PortProxy) execListener(pm types.PortMapping) {
for _, portBindings := range pm.Ports {
for _, portBinding := range portBindings {
logrus.Debugf("received the following port: [%s] from portMapping: %+v", portBinding.HostPort, pm)
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if pm.Remove {
p.mutex.Lock()
if listener, exist := p.activeListeners[port]; exist {
logrus.Debugf("closing listener for port: %d", port)
if err := listener.Close(); err != nil {
logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err)
}
func (p *PortProxy) exec(pm types.PortMapping) {
for portProto, portBindings := range pm.Ports {
logrus.Debugf("received the following port: [%s] and protocol: [%s] from portMapping: %+v", portProto.Port(), portProto.Proto(), pm)

switch gvisorTypes.TransportProtocol(portProto.Proto()) {
case gvisorTypes.TCP:
p.handleTCP(portBindings, pm.Remove)
case gvisorTypes.UDP:
p.handleUDP(portBindings, pm.Remove)
default:
logrus.Warnf("unsupported protocol: [%s]", portProto.Proto())
}
}
}

func (p *PortProxy) handleUDP(portBindings []nat.PortBinding, remove bool) {
for _, portBinding := range portBindings {
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if remove {
p.udpConnMutex.Lock()
if udpConn, exist := p.activeUDPConns[port]; exist {
if err := udpConn.Close(); err != nil {
logrus.Errorf("error closing UDPConn for port [%s]: %s", portBinding.HostPort, err)
}
delete(p.activeListeners, port)
p.mutex.Unlock()
continue
}
addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
l, err := net.Listen("tcp", addr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
delete(p.activeUDPConns, port)
p.udpConnMutex.Unlock()
logrus.Debugf("closing UDPConn for port: %d", port)
continue
}

// the localAddress IP section can either be 0.0.0.0 or 127.0.0.1
localAddress := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
sourceAddr, err := net.ResolveUDPAddr("udp", localAddress)
if err != nil {
logrus.Errorf("failed to resolve UDP source address [%s]: %s", sourceAddr, err)
continue
}

c, err := net.ListenUDP("udp", sourceAddr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
}

forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, portBinding.HostPort)
targetAddr, err := net.ResolveUDPAddr("udp", forwardAddr)
if err != nil {
c.Close()
logrus.Errorf("failed to resolve UDP target address [%s]: %s", targetAddr, err)
mook-as marked this conversation as resolved.
Show resolved Hide resolved
continue
}

p.udpConnMutex.Lock()
p.activeUDPConns[port] = c
p.udpConnMutex.Unlock()
logrus.Debugf("created UDPConn for: %v", sourceAddr)

go p.acceptUDPConn(c, targetAddr)
}
}

func (p *PortProxy) acceptUDPConn(sourceConn *net.UDPConn, targetAddr *net.UDPAddr) {
targetConn, err := net.DialUDP("udp", nil, targetAddr)
if err != nil {
logrus.Errorf("failed to connect to target address: %s : %s", targetAddr, err)
return
}
defer targetConn.Close()
p.wg.Add(1)
for {
b := make([]byte, p.config.UDPBufferSize)
n, addr, err := sourceConn.ReadFromUDP(b)
if err != nil && n == 0 {
logrus.Errorf("error reading UDP packet from source: %s : %s", addr, err)
if errors.Is(err, net.ErrClosed) {
p.wg.Done()
break
}
continue
}
logrus.Debugf("received %d data from %s", n, addr)

n, err = targetConn.Write(b[:n])
if err != nil {
logrus.Errorf("error forwarding UDP packet to target: %s : %s", targetAddr, err)
if errors.Is(err, net.ErrClosed) {
p.wg.Done()
break
}
p.mutex.Lock()
p.activeListeners[port] = l
p.mutex.Unlock()
logrus.Debugf("created listener for: %s", addr)
go p.acceptTraffic(l, portBinding.HostPort)
continue
}
logrus.Debugf("sent %d data to %s", n, targetAddr)
}
}

func (p *PortProxy) handleTCP(portBindings []nat.PortBinding, remove bool) {
for _, portBinding := range portBindings {
port, err := nat.ParsePort(portBinding.HostPort)
if err != nil {
logrus.Errorf("parsing port error: %s", err)
continue
}
if remove {
p.listenerMutex.Lock()
if listener, exist := p.activeListeners[port]; exist {
logrus.Debugf("closing listener for port: %d", port)
if err := listener.Close(); err != nil {
logrus.Errorf("error closing listener for port [%s]: %s", portBinding.HostPort, err)
}
}
delete(p.activeListeners, port)
p.listenerMutex.Unlock()
continue
}
addr := net.JoinHostPort(portBinding.HostIP, portBinding.HostPort)
l, err := net.Listen("tcp", addr)
if err != nil {
logrus.Errorf("failed creating listener for published port [%s]: %s", portBinding.HostPort, err)
continue
}
p.listenerMutex.Lock()
p.activeListeners[port] = l
p.listenerMutex.Unlock()
logrus.Debugf("created listener for: %s", addr)
go p.acceptTraffic(l, portBinding.HostPort)
}
}

func (p *PortProxy) acceptTraffic(listener net.Listener, port string) {
forwardAddr := net.JoinHostPort(p.upstreamAddress, port)
forwardAddr := net.JoinHostPort(p.config.UpstreamAddress, port)
for {
conn, err := listener.Accept()
if err != nil {
Expand All @@ -124,7 +237,7 @@ func (p *PortProxy) acceptTraffic(listener net.Listener, port string) {
logrus.Errorf("port proxy listener failed to accept: %s", err)
continue
}
logrus.Debugf("port proxy accepted connection from %s", conn.RemoteAddr())
logrus.Debugf("port proxy accepted TCP connection from %s", conn.RemoteAddr())
p.wg.Add(1)

go func(conn net.Conn) {
Expand All @@ -139,6 +252,9 @@ func (p *PortProxy) Close() error {
// Close all the active listeners
p.cleanupListeners()

// Close all active UDP connections
p.cleanupUDPConns()

// Close the listener first to prevent new connections.
err := p.listener.Close()
if err != nil {
Expand All @@ -155,7 +271,17 @@ func (p *PortProxy) Close() error {
}

func (p *PortProxy) cleanupListeners() {
p.listenerMutex.Lock()
defer p.listenerMutex.Unlock()
for _, l := range p.activeListeners {
_ = l.Close()
}
}

func (p *PortProxy) cleanupUDPConns() {
p.udpConnMutex.Lock()
defer p.udpConnMutex.Unlock()
for _, c := range p.activeUDPConns {
mook-as marked this conversation as resolved.
Show resolved Hide resolved
_ = c.Close()
}
}
Loading
Loading