From c5c692f9b349111ae86975fa36f485c65b467bce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 20:16:37 +0800 Subject: [PATCH] Implementation read waiter for socks5 UDP and UoT --- common/uot/conn.go | 37 +++++++++++++++++++++++++++ protocol/socks/client.go | 3 ++- protocol/socks/packet.go | 10 ++++++++ protocol/socks/packet_wait.go | 48 +++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 protocol/socks/packet_wait.go diff --git a/common/uot/conn.go b/common/uot/conn.go index fd7d89865..cf289d59a 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "net" + "os" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -13,11 +14,17 @@ import ( N "github.com/sagernet/sing/common/network" ) +var ( + _ N.NetPacketConn = (*Conn)(nil) + _ N.PacketReadWaiter = (*Conn)(nil) +) + type Conn struct { net.Conn isConnect bool destination M.Socksaddr writer N.VectorisedWriter + newBuffer func() *buf.Buffer } func NewConn(conn net.Conn, request Request) *Conn { @@ -141,6 +148,36 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) } +func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + c.newBuffer = newBuffer +} + +func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + if c.newBuffer == nil { + return nil, M.Socksaddr{}, os.ErrInvalid + } + if c.isConnect { + destination = c.destination + } else { + destination, err = AddrParser.ReadAddrPort(c.Conn) + if err != nil { + return + } + } + var length uint16 + err = binary.Read(c.Conn, binary.BigEndian, &length) + if err != nil { + return + } + buffer = c.newBuffer() + _, err = buffer.ReadFullFrom(c.Conn, int(length)) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, E.Cause(err, "UoT read") + } + return +} + func (c *Conn) NeedAdditionalReadDeadline() bool { return true } diff --git a/protocol/socks/client.go b/protocol/socks/client.go index 45c0d9033..fd0a34db8 100644 --- a/protocol/socks/client.go +++ b/protocol/socks/client.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -147,7 +148,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock tcpConn.Close() return nil, err } - return NewAssociateConn(udpConn, address, tcpConn), nil + return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil } return nil, os.ErrInvalid } diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 4df672c37..21860ea2e 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -7,6 +7,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -17,6 +18,8 @@ import ( // | 2 | 1 | 1 | Variable | 2 | Variable | // +----+------+------+----------+----------+----------+ +var ErrInvalidPacket = E.New("socks5: invalid packet") + type AssociatePacketConn struct { N.NetPacketConn remoteAddr M.Socksaddr @@ -31,6 +34,7 @@ func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underly } } +// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead. func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn { return &AssociatePacketConn{ NetPacketConn: bufio.NewUnbindPacketConn(conn), @@ -49,6 +53,9 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro if err != nil { return } + if n < 3 { + return 0, nil, ErrInvalidPacket + } c.remoteAddr = M.SocksaddrFromNet(addr) reader := bytes.NewReader(p[3:n]) destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) @@ -92,6 +99,9 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Sock if err != nil { return M.Socksaddr{}, err } + if buffer.Len() < 3 { + return M.Socksaddr{}, ErrInvalidPacket + } c.remoteAddr = destination buffer.Advance(3) destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) diff --git a/protocol/socks/packet_wait.go b/protocol/socks/packet_wait.go new file mode 100644 index 000000000..9b9047b6e --- /dev/null +++ b/protocol/socks/packet_wait.go @@ -0,0 +1,48 @@ +package socks + +import ( + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.PacketReadWaitCreator = (*AssociatePacketConn)(nil) + +func (c *AssociatePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(c.NetPacketConn) + if !isReadWaiter { + return nil, false + } + return &AssociatePacketReadWaiter{c, readWaiter}, true +} + +var _ N.PacketReadWaiter = (*AssociatePacketReadWaiter)(nil) + +type AssociatePacketReadWaiter struct { + conn *AssociatePacketConn + readWaiter N.PacketReadWaiter +} + +func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + w.readWaiter.InitializeReadWaiter(newBuffer) +} + +func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, destination, err = w.readWaiter.WaitReadPacket() + if err != nil { + return + } + if buffer.Len() < 3 { + buffer.Release() + return nil, M.Socksaddr{}, ErrInvalidPacket + } + w.conn.remoteAddr = destination + buffer.Advance(3) + destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + return +}