From 25b4d58d791c953d2fa927d682832de3814af030 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 21:25:48 +0800 Subject: [PATCH] Merge ThreadSafeReader into ReadWaiter interface --- common/bufio/bind_wait.go | 8 ++-- common/bufio/copy.go | 14 +++++- common/bufio/copy_direct_posix.go | 78 ++++++++++++++++++++++--------- common/network/direct.go | 8 +++- common/network/thread.go | 6 +++ common/pipe/pipe_wait.go | 3 +- common/uot/conn.go | 3 +- protocol/socks/packet_wait.go | 5 +- 8 files changed, 90 insertions(+), 35 deletions(-) diff --git a/common/bufio/bind_wait.go b/common/bufio/bind_wait.go index 724a76e11..4834c4548 100644 --- a/common/bufio/bind_wait.go +++ b/common/bufio/bind_wait.go @@ -12,8 +12,8 @@ type BindPacketReadWaiter struct { readWaiter N.PacketReadWaiter } -func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool { + return w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom) } func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { @@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct { addr M.Socksaddr } -func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool { + return w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom) } func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 3bdb164ba..2d1946b5d 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -57,15 +57,19 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + //nolint:staticcheck + //goland:noinspection GoDeprecation safeSrc := N.IsSafeReader(source) headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) if safeSrc != nil { if headroom == 0 { + //nolint:staticcheck + //goland:noinspection GoDeprecation return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters) } } readWaiter, isReadWaiter := CreateReadWaiter(source) - if isReadWaiter { + if isReadWaiter && (readWaiter.InitializeReadWaiter(nil, headroom > 0) || headroom == 0 || common.LowMemory) { var handled bool handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) if handled { @@ -113,6 +117,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so } } +// Deprecated: Use ReadWaiter interface instead. func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var notFirstTime bool for { @@ -256,6 +261,8 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } } + //nolint:staticcheck + //goland:noinspection GoDeprecation safeSrc := N.IsSafePacketReader(source) frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) @@ -263,6 +270,8 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, if safeSrc != nil { if headroom == 0 { var copyN int64 + //nolint:staticcheck + //goland:noinspection GoDeprecation copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0) n += copyN return @@ -273,7 +282,7 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, copeN int64 ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) - if isReadWaiter { + if isReadWaiter && (readWaiter.InitializeReadWaiter(nil, headroom > 0) || headroom == 0 || common.LowMemory) { handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) if handled { n += copeN @@ -285,6 +294,7 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } +// Deprecated: Use PacketReadWaiter interface instead. func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { var buffer *buf.Buffer var destination M.Socksaddr diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 06da27d0c..25bec174d 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -19,6 +19,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour handled = true frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) + needHeadroom := frontHeadroom > 0 || rearHeadroom > 0 bufferSize := N.CalculateMTU(source, destination) if bufferSize > 0 { bufferSize += frontHeadroom + rearHeadroom @@ -27,31 +28,45 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour } var ( buffer *buf.Buffer - readBuffer *buf.Buffer + resultBuffer *buf.Buffer notFirstTime bool ) - source.InitializeReadWaiter(func() *buf.Buffer { + externalBuffer := source.InitializeReadWaiter(func() *buf.Buffer { + if buffer != nil { + buffer.Release() + } buffer = buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) return readBuffer - }) - defer source.InitializeReadWaiter(nil) + }, needHeadroom) + defer source.InitializeReadWaiter(nil, false) for { - _, err = source.WaitReadBuffer() + resultBuffer, err = source.WaitReadBuffer() if err != nil { + if buffer != nil { + buffer.Release() + } if errors.Is(err, io.EOF) { err = nil return } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destination.WriteBuffer(buffer) + dataLen := resultBuffer.Len() + if externalBuffer { + err = destination.WriteBuffer(resultBuffer) + } else { + buffer.Resize(resultBuffer.Start(), dataLen) + err = destination.WriteBuffer(buffer) + } if err != nil { - buffer.Release() + if externalBuffer { + resultBuffer.Release() + } else { + buffer.Release() + } if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -72,6 +87,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe handled = true frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) + needHeadroom := frontHeadroom > 0 || rearHeadroom > 0 bufferSize := N.CalculateMTU(source, destinationConn) if bufferSize > 0 { bufferSize += frontHeadroom + rearHeadroom @@ -79,28 +95,42 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe bufferSize = buf.UDPBufferSize } var ( - buffer *buf.Buffer - readBuffer *buf.Buffer - destination M.Socksaddr + buffer *buf.Buffer + resultBuffer *buf.Buffer + destination M.Socksaddr ) - source.InitializeReadWaiter(func() *buf.Buffer { + externalBuffer := source.InitializeReadWaiter(func() *buf.Buffer { + if buffer != nil { + buffer.Release() + } buffer = buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) return readBuffer - }) - defer source.InitializeReadWaiter(nil) + }, needHeadroom) + defer source.InitializeReadWaiter(nil, false) for { _, destination, err = source.WaitReadPacket() if err != nil { + if buffer != nil { + buffer.Release() + } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destinationConn.WritePacket(buffer, destination) + dataLen := resultBuffer.Len() + if externalBuffer { + err = destinationConn.WritePacket(resultBuffer, destination) + } else { + buffer.Resize(resultBuffer.Start(), dataLen) + err = destinationConn.WritePacket(buffer, destination) + } if err != nil { - buffer.Release() + if externalBuffer { + resultBuffer.Release() + } else { + buffer.Release() + } if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -136,7 +166,7 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { return nil, false } -func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { +func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool { w.readErr = nil if newBuffer == nil { w.readFunc = nil @@ -161,6 +191,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { return true } } + return true } func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { @@ -202,7 +233,7 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) return nil, false } -func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { +func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool { w.readErr = nil w.readFrom = M.Socksaddr{} if newBuffer == nil { @@ -234,6 +265,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf return true } } + return true } func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { diff --git a/common/network/direct.go b/common/network/direct.go index b64567645..a709b99bc 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -5,8 +5,12 @@ import ( M "github.com/sagernet/sing/common/metadata" ) +type ReadWaitable interface { + InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) (externalBuffer bool) +} + type ReadWaiter interface { - InitializeReadWaiter(newBuffer func() *buf.Buffer) + ReadWaitable WaitReadBuffer() (buffer *buf.Buffer, err error) } @@ -15,7 +19,7 @@ type ReadWaitCreator interface { } type PacketReadWaiter interface { - InitializeReadWaiter(newBuffer func() *buf.Buffer) + ReadWaitable WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) } diff --git a/common/network/thread.go b/common/network/thread.go index a492fdd06..58ccebbb6 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -10,11 +10,15 @@ type ThreadUnsafeWriter interface { WriteIsThreadUnsafe() } +// Deprecated: Use ReadWaiter interface instead. type ThreadSafeReader interface { + // Deprecated: Use ReadWaiter interface instead. ReadBufferThreadSafe() (buffer *buf.Buffer, err error) } +// Deprecated: Use ReadWaiter interface instead. type ThreadSafePacketReader interface { + // Deprecated: Use ReadWaiter interface instead. ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) } @@ -23,6 +27,7 @@ func IsUnsafeWriter(writer any) bool { return isUnsafe } +// Deprecated: Use ReadWaiter interface instead. func IsSafeReader(reader any) ThreadSafeReader { if safeReader, isSafe := reader.(ThreadSafeReader); isSafe { return safeReader @@ -39,6 +44,7 @@ func IsSafeReader(reader any) ThreadSafeReader { return nil } +// Deprecated: Use ReadWaiter interface instead. func IsSafePacketReader(reader any) ThreadSafePacketReader { if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe { return safeReader diff --git a/common/pipe/pipe_wait.go b/common/pipe/pipe_wait.go index 27ea27b02..7ad87b6fa 100644 --- a/common/pipe/pipe_wait.go +++ b/common/pipe/pipe_wait.go @@ -11,8 +11,9 @@ import ( var _ N.ReadWaiter = (*pipe)(nil) -func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) { +func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool { p.newBuffer = newBuffer + return true } func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) { diff --git a/common/uot/conn.go b/common/uot/conn.go index cf289d59a..9d1beb0f1 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -148,8 +148,9 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) } -func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { +func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool { c.newBuffer = newBuffer + return true } func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { diff --git a/protocol/socks/packet_wait.go b/protocol/socks/packet_wait.go index 9b9047b6e..413a29586 100644 --- a/protocol/socks/packet_wait.go +++ b/protocol/socks/packet_wait.go @@ -24,8 +24,9 @@ type AssociatePacketReadWaiter struct { readWaiter N.PacketReadWaiter } -func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool { + w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom) + return true } func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {