From 3f87c8370d73ee18868bad5e0bc2ccc550b81f81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Nov 2024 19:45:41 +0800 Subject: [PATCH] Fix HandshakeFailure usages --- common/network/dialer.go | 3 ++- common/network/handshake.go | 8 +++++++- protocol/socks/lazy.go | 5 +++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/common/network/dialer.go b/common/network/dialer.go index 7c3f3e26..6b66d691 100644 --- a/common/network/dialer.go +++ b/common/network/dialer.go @@ -2,6 +2,7 @@ package network import ( "context" + "github.com/sagernet/sing/common/buf" "net" "net/netip" @@ -14,7 +15,7 @@ type Dialer interface { } type PayloadDialer interface { - DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payload [][]byte) (net.Conn, error) + DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) } type ParallelDialer interface { diff --git a/common/network/handshake.go b/common/network/handshake.go index 5f13492b..d2203e05 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -36,7 +36,13 @@ func ReportHandshakeFailure(reporter any, err error) error { func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { if err != nil { if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { - err = E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { + hErr := handshakeConn.HandshakeFailure(err) + err = E.Append(err, hErr, func(err error) error { + if closer, isCloser := reporter.(io.Closer); isCloser { + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close") + }) + } return E.Cause(err, "write handshake failure") }) } else { diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go index 34689814..f98ac3d5 100644 --- a/protocol/socks/lazy.go +++ b/protocol/socks/lazy.go @@ -2,6 +2,7 @@ package socks import ( "net" + "os" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -48,7 +49,7 @@ func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error { func (c *LazyConn) HandshakeFailure(err error) error { if c.responseWritten { - return nil + return os.ErrInvalid } defer func() { c.responseWritten = true @@ -130,7 +131,7 @@ func (c *LazyAssociatePacketConn) HandshakeSuccess() error { func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error { if c.responseWritten { - return nil + return os.ErrInvalid } defer func() { c.responseWritten = true