diff --git a/channels.go b/channels.go index f7d3931..a846d0f 100644 --- a/channels.go +++ b/channels.go @@ -122,7 +122,7 @@ func handleAlias(newChannel ssh.NewChannel, sshConn *SSHConnection, state *State sshConn.Listeners.Store(conn.RemoteAddr(), nil) - copyBoth(conn, connection, false) + copyBoth(conn, connection, true) sshConn.CleanUp(state) } diff --git a/requests.go b/requests.go index c58f664..b586827 100644 --- a/requests.go +++ b/requests.go @@ -227,7 +227,7 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state } } - go copyBoth(cl, newChan, false) + go copyBoth(cl, newChan, true) go ssh.DiscardRequests(newReqs) } } @@ -245,30 +245,29 @@ func copyBoth(writer net.Conn, reader ssh.Channel, wait bool) { if wait { wg.Add(1) defer wg.Done() - } else { - defer closeBoth() } _, err := io.Copy(writer, reader) if err != nil && *debug { - log.Println("Error writing to reader:", err) + log.Println("Error writing to writer:", err) } }() - if wait { - wg.Add(1) - } else { - defer closeBoth() - } + go func() { + if wait { + wg.Add(1) + defer wg.Done() + } + + _, err := io.Copy(reader, writer) + if err != nil && *debug { + log.Println("Error writing to reader:", err) + } + }() - _, err := io.Copy(reader, writer) - if err != nil && *debug { - log.Println("Error writing to writer:", err) - } if wait { - wg.Done() + wg.Wait() } - wg.Wait() closeBoth() }