diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index d39d1bece0..39adce251e 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -36,6 +36,7 @@ func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, errorHandler: errorHandler, dialer: dialer, reservedForEndpoint: make(map[netip.AddrPort][3]uint8), + done: make(chan struct{}), isConnect: isConnect, connectAddr: connectAddr, reserved: reserved, @@ -88,8 +89,7 @@ func (c *ClientBind) connect() (*wireConn, error) { func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { select { case <-c.done: - err = net.ErrClosed - return + c.done = make(chan struct{}) default: } return []conn.ReceiveFunc{c.receive}, 0, nil @@ -129,16 +129,8 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) return } -func (c *ClientBind) Reset() { - common.Close(common.PtrOrNil(c.conn)) -} - func (c *ClientBind) Close() error { common.Close(common.PtrOrNil(c.conn)) - if c.done == nil { - c.done = make(chan struct{}) - return nil - } select { case <-c.done: default: @@ -165,7 +157,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { } copy(b[1:4], reserved[:]) } - _, err = udpConn.WriteTo(b, M.SocksaddrFromNetIP(destination)) + _, err = udpConn.WriteToUDPAddrPort(b, destination) if err != nil { udpConn.Close() return err @@ -192,10 +184,18 @@ func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved type wireConn struct { net.PacketConn + conn net.Conn access sync.Mutex done chan struct{} } +func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + if w.conn != nil { + return w.conn.Write(b) + } + return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr()) +} + func (w *wireConn) Close() error { w.access.Lock() defer w.access.Unlock()