From dd810c53429a4f6693357711831afd9a918c72ba Mon Sep 17 00:00:00 2001 From: Arnaud Vrac Date: Thu, 25 Nov 2021 22:51:36 +0100 Subject: [PATCH 1/2] Allow client to use port 0 when requesting reverse port forwarding Bind the port to forward before calling the ReversePortForwardingCallback callback, with the actual bound port instead of 0. --- tcpip.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tcpip.go b/tcpip.go index 335fda6..75b6336 100644 --- a/tcpip.go +++ b/tcpip.go @@ -108,9 +108,6 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log parse failure return false, []byte{} } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) ln, err := net.Listen("tcp", addr) if err != nil { @@ -119,6 +116,10 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go } _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, uint32(destPort)) { + ln.Close() + return false, []byte("port forwarding is disabled") + } h.Lock() h.forwards[addr] = ln h.Unlock() From 4965f6344a01171a6004d727632d56fce1d3a63c Mon Sep 17 00:00:00 2001 From: Nicolas Escande Date: Fri, 13 Oct 2023 11:45:18 +0200 Subject: [PATCH 2/2] fix multiple ports forward requests to port 0 --- tcpip.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tcpip.go b/tcpip.go index 75b6336..390c985 100644 --- a/tcpip.go +++ b/tcpip.go @@ -113,8 +113,11 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go if err != nil { // TODO: log listen failure return false, []byte{} + } else { + // addr might not be valid anymore if bind port was 0 + addr = ln.Addr().String() } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + _, destPortStr, _ := net.SplitHostPort(addr) destPort, _ := strconv.Atoi(destPortStr) if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, uint32(destPort)) { ln.Close()