From 06e36384832721ef4fe406ff47cf5ee308180e57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Dec 2023 10:22:29 +0800 Subject: [PATCH] Add reserve support for buffer --- common/buf/buffer.go | 80 +++++++++++++++++-------------- common/bufio/copy.go | 33 ++++++------- common/bufio/copy_direct_posix.go | 14 +++--- common/network/direct.go | 25 ++++++---- common/pipe/pipe_wait.go | 7 ++- common/udpnat/conn_wait.go | 7 ++- common/uot/conn_wait.go | 7 ++- 7 files changed, 90 insertions(+), 83 deletions(-) diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 026b8a5e4..d22994642 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "io" "net" - "strconv" "sync/atomic" "github.com/sagernet/sing/common" @@ -17,14 +16,15 @@ type Buffer struct { data []byte start int end int + length int refs atomic.Int32 managed bool - closed bool } func New() *Buffer { return &Buffer{ data: Get(BufferSize), + length: BufferSize, managed: true, } } @@ -32,6 +32,7 @@ func New() *Buffer { func NewPacket() *Buffer { return &Buffer{ data: Get(UDPBufferSize), + length: UDPBufferSize, managed: true, } } @@ -41,40 +42,29 @@ func NewSize(size int) *Buffer { return &Buffer{} } else if size > 65535 { return &Buffer{ - data: make([]byte, size), + data: make([]byte, size), + length: size, } } return &Buffer{ data: Get(size), + length: size, managed: true, } } -// Deprecated: use New instead. -func StackNew() *Buffer { - return New() -} - -// Deprecated: use NewPacket instead. -func StackNewPacket() *Buffer { - return NewPacket() -} - -// Deprecated: use NewSize instead. -func StackNewSize(size int) *Buffer { - return NewSize(size) -} - func As(data []byte) *Buffer { return &Buffer{ - data: data, - end: len(data), + data: data, + end: len(data), + length: len(data), } } func With(data []byte) *Buffer { return &Buffer{ - data: data, + data: data, + length: len(data), } } @@ -88,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) { func (b *Buffer) Extend(n int) []byte { end := b.end + n - if end > cap(b.data) { - panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n)) + if end > b.length { + panic(F.ToString("buffer overflow: length ", b.length, ",end ", b.end, ", need ", n)) } ext := b.data[b.end:end] b.end = end @@ -111,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } - n = copy(b.data[b.end:], data) + n = copy(b.data[b.end:b.length], data) b.end += n return } func (b *Buffer) ExtendHeader(n int) []byte { if b.start < n { - panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n)) + panic(F.ToString("buffer overflow: length ", b.length, ",start ", b.start, ", need ", n)) } b.start -= n return b.data[b.start : b.start+n] @@ -171,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) { } func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) { - if b.end+size > b.Cap() { + if b.end+size > b.length { return 0, io.ErrShortBuffer } n, err = io.ReadFull(r, b.data[b.end:b.end+size]) @@ -208,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } - n = copy(b.data[b.end:], s) + n = copy(b.data[b.end:b.length], s) b.end += n return } @@ -223,7 +213,7 @@ func (b *Buffer) WriteZero() error { } func (b *Buffer) WriteZeroN(n int) error { - if b.end+n > b.Cap() { + if b.end+n > b.length { return io.ErrShortBuffer } for i := b.end; i <= b.end+n; i++ { @@ -272,9 +262,24 @@ func (b *Buffer) Resize(start, end int) { b.end = b.start + end } +func (b *Buffer) Reserve(n int) { + if n > b.length { + panic(F.ToString("buffer overflow: length ", b.length, ", need ", n)) + } + b.length -= n +} + +func (b *Buffer) OverLength(n int) { + if b.length+n > len(b.data) { + panic(F.ToString("buffer overflow: length ", len(b.data), ", need ", b.length+n)) + } + b.length += n +} + func (b *Buffer) Reset() { b.start = 0 b.end = 0 + b.length = len(b.data) } // Deprecated: use Reset instead. @@ -291,19 +296,19 @@ func (b *Buffer) DecRef() { } func (b *Buffer) Release() { - if b == nil || b.closed || !b.managed { + if b == nil || !b.managed { return } if b.refs.Load() > 0 { return } common.Must(Put(b.data)) - *b = Buffer{closed: true} + *b = Buffer{} } func (b *Buffer) Leak() { if debug.Enabled { - if b == nil || b.closed || !b.managed { + if b == nil || !b.managed { return } refs := b.refs.Load() @@ -319,7 +324,7 @@ func (b *Buffer) Leak() { func (b *Buffer) Cut(start int, end int) *Buffer { b.start += start - b.end = len(b.data) - end + b.end = b.length - end return &Buffer{ data: b.data[b.start:b.end], } @@ -334,7 +339,7 @@ func (b *Buffer) Len() int { } func (b *Buffer) Cap() int { - return len(b.data) + return b.length } func (b *Buffer) Bytes() []byte { @@ -342,7 +347,7 @@ func (b *Buffer) Bytes() []byte { } func (b *Buffer) Slice() []byte { - return b.data + return b.data[:b.length] } func (b *Buffer) From(n int) []byte { @@ -362,11 +367,11 @@ func (b *Buffer) Index(start int) []byte { } func (b *Buffer) FreeLen() int { - return b.Cap() - b.end + return b.length - b.end } func (b *Buffer) FreeBytes() []byte { - return b.data[b.end:b.Cap()] + return b.data[b.end:b.length] } func (b *Buffer) IsEmpty() bool { @@ -374,7 +379,7 @@ func (b *Buffer) IsEmpty() bool { } func (b *Buffer) IsFull() bool { - return b.end == b.Cap() + return b.end == b.length } func (b *Buffer) ToOwned() *Buffer { @@ -382,5 +387,6 @@ func (b *Buffer) ToOwned() *Buffer { copy(n.data[b.start:b.end], b.data[b.start:b.end]) n.start = b.start n.end = b.end + n.length = b.length return n } diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 53c6e2ec3..d42231356 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -81,12 +81,11 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so defer buffer.DecRef() frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) var notFirstTime bool for { - readBuffer.Resize(frontHeadroom, 0) - err = source.ReadBuffer(readBuffer) + err = source.ReadBuffer(buffer) if err != nil { if errors.Is(err, io.EOF) { err = nil @@ -94,8 +93,8 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverLength(rearHeadroom) err = destination.WriteBuffer(buffer) if err != nil { if !notFirstTime { @@ -126,10 +125,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, var notFirstTime bool for { buffer := buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - err = source.ReadBuffer(readBuffer) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) + err = source.ReadBuffer(buffer) if err != nil { buffer.Release() if errors.Is(err, io.EOF) { @@ -138,8 +136,8 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverLength(rearHeadroom) err = destination.WriteBuffer(buffer) if err != nil { buffer.Leak() @@ -263,16 +261,15 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri var destination M.Socksaddr for { buffer := buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - destination, err = source.ReadPacket(readBuffer) + buffer.Resize(frontHeadroom, 0) + buffer.Reserve(rearHeadroom) + destination, err = source.ReadPacket(buffer) if err != nil { buffer.Release() return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() + buffer.OverLength(rearHeadroom) err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Leak() diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 07506789c..5ef83ec0e 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -104,11 +104,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { w.options = options w.readFunc = func(fd uintptr) (done bool) { - buffer, readBuffer := w.options.NewBuffer() + buffer := w.options.NewBuffer() var readN int - readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes()) + readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) if readN > 0 { - buffer.Resize(readBuffer.Start(), readN) + buffer.Truncate(readN) } else { buffer.Release() buffer = nil @@ -119,6 +119,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (nee if readN == 0 { w.readErr = io.EOF } + w.options.PostReturn(buffer) w.buffer = buffer return true } @@ -168,12 +169,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { w.options = options w.readFunc = func(fd uintptr) (done bool) { - buffer, readBuffer := w.options.NewPacketBuffer() + buffer := w.options.NewPacketBuffer() var readN int var from syscall.Sockaddr - readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0) + readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) if readN > 0 { - buffer.Resize(readBuffer.Start(), readN) + buffer.Truncate(readN) } else { buffer.Release() buffer = nil @@ -189,6 +190,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() } } + w.options.PostReturn(buffer) w.buffer = buffer return true } diff --git a/common/network/direct.go b/common/network/direct.go index 7a28eca6a..4c0a629d5 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -19,30 +19,35 @@ func (o ReadWaitOptions) NeedHeadroom() bool { return o.FrontHeadroom > 0 || o.RearHeadroom > 0 } -func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { +func (o ReadWaitOptions) NewBuffer() *buf.Buffer { return o.newBuffer(buf.BufferSize) } -func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { +func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer { return o.newBuffer(buf.UDPBufferSize) } -func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) { +func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { var bufferSize int if o.MTU > 0 { bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom } else { bufferSize = defaultBufferSize } - buffer = buf.NewSize(bufferSize) + buffer := buf.NewSize(bufferSize) + if o.FrontHeadroom > 0 { + buffer.Resize(o.FrontHeadroom, 0) + } if o.RearHeadroom > 0 { - readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom]) - } else { - readBuffer = buffer + buffer.Reserve(o.RearHeadroom) + } + return buffer +} + +func (o ReadWaitOptions) PostReturn(buffer *buf.Buffer) { + if o.RearHeadroom > 0 { + buffer.OverLength(o.RearHeadroom) } - readBuffer.Resize(o.FrontHeadroom, 0) - return } type ReadWaiter interface { diff --git a/common/pipe/pipe_wait.go b/common/pipe/pipe_wait.go index e1a241351..409984060 100644 --- a/common/pipe/pipe_wait.go +++ b/common/pipe/pipe_wait.go @@ -33,17 +33,16 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) { case isClosedChan(p.readDeadline.wait()): return nil, os.ErrDeadlineExceeded } - var readBuffer *buf.Buffer select { case bw := <-p.rdRx: - buffer, readBuffer = p.readWaitOptions.NewBuffer() + buffer = p.readWaitOptions.NewBuffer() var nr int - nr, err = readBuffer.Write(bw) + nr, err = buffer.Write(bw) if err != nil { buffer.Release() return } - buffer.Resize(readBuffer.Start(), readBuffer.Len()) + p.readWaitOptions.PostReturn(buffer) p.rdTx <- nr return case <-p.localDone: diff --git a/common/udpnat/conn_wait.go b/common/udpnat/conn_wait.go index 5dde5315b..2e6d741dc 100644 --- a/common/udpnat/conn_wait.go +++ b/common/udpnat/conn_wait.go @@ -19,15 +19,14 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er select { case p := <-c.data: if c.readWaitOptions.NeedHeadroom() { - var readBuffer *buf.Buffer - buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() - _, err = readBuffer.Write(p.data.Bytes()) + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.Write(p.data.Bytes()) if err != nil { buffer.Release() return } + c.readWaitOptions.PostReturn(buffer) p.data.Release() - buffer.Resize(readBuffer.Start(), readBuffer.Len()) } else { buffer = p.data } diff --git a/common/uot/conn_wait.go b/common/uot/conn_wait.go index eecb3601c..7341c882a 100644 --- a/common/uot/conn_wait.go +++ b/common/uot/conn_wait.go @@ -28,13 +28,12 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er if err != nil { return } - var readBuffer *buf.Buffer - buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() - _, err = readBuffer.ReadFullFrom(c.Conn, int(length)) + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.ReadFullFrom(c.Conn, int(length)) if err != nil { buffer.Release() return nil, M.Socksaddr{}, E.Cause(err, "UoT read") } - buffer.Resize(readBuffer.Start(), readBuffer.Len()) + c.readWaitOptions.PostReturn(buffer) return }