diff --git a/client.go b/client.go index a65a6613e9..414a3dfd26 100644 --- a/client.go +++ b/client.go @@ -937,7 +937,7 @@ func (cl *Client) initiateHandshakes(ctx context.Context, c *PeerConn, t *Torren // If we're sending the v1 infohash, and we know the v2 infohash, set the v2 upgrade bit. This // means the peer can send the v2 infohash in the handshake to upgrade the connection. localReservedBits.SetBit(pp.ExtensionBitV2Upgrade, g.Some(handshakeIh) == t.infoHash && t.infoHashV2.Ok) - ih, err := cl.connBtHandshake(c, &handshakeIh, localReservedBits) + ih, err := cl.connBtHandshake(context.TODO(), c, &handshakeIh, localReservedBits) if err != nil { return fmt.Errorf("bittorrent protocol handshake: %w", err) } @@ -1015,7 +1015,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) { err = errors.New("connection does not have required header obfuscation") return } - ih, err := cl.connBtHandshake(c, nil, cl.config.Extensions) + ih, err := cl.connBtHandshake(context.TODO(), c, nil, cl.config.Extensions) if err != nil { return nil, fmt.Errorf("during bt handshake: %w", err) } @@ -1039,8 +1039,8 @@ func init() { &successfulPeerWireProtocolHandshakePeerReservedBytes) } -func (cl *Client) connBtHandshake(c *PeerConn, ih *metainfo.Hash, reservedBits PeerExtensionBits) (ret metainfo.Hash, err error) { - res, err := pp.Handshake(c.rw(), ih, cl.peerID, reservedBits) +func (cl *Client) connBtHandshake(ctx context.Context, c *PeerConn, ih *metainfo.Hash, reservedBits PeerExtensionBits) (ret metainfo.Hash, err error) { + res, err := pp.Handshake(ctx, c.rw(), ih, cl.peerID, reservedBits) if err != nil { return } diff --git a/go.mod b/go.mod index 479c24105e..05ca5af7d5 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/RoaringBitmap/roaring v1.2.3 github.com/ajwerner/btree v0.0.0-20211221152037-f427b3e689c0 github.com/alexflint/go-arg v1.4.3 - github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a + github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8 github.com/anacrolix/dht/v2 v2.19.2-0.20221121215055-066ad8494444 github.com/anacrolix/envpprof v1.3.0 diff --git a/go.sum b/go.sum index 7e27e621b5..14210a02e5 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ github.com/anacrolix/backtrace v0.0.0-20221205112523-22a61db8f82e h1:A0Ty9UeyBDI github.com/anacrolix/backtrace v0.0.0-20221205112523-22a61db8f82e/go.mod h1:4YFqy+788tLJWtin2jNliYVJi+8aDejG9zcu/2/pONw= github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a h1:KCP9QvHlLoUQBOaTf/YCuOzG91Ym1cPB6S68O4Q3puo= github.com/anacrolix/bargle v0.0.0-20220630015206-d7a4d433886a/go.mod h1:9xUiZbkh+94FbiIAL1HXpAIBa832f3Mp07rRPl5c5RQ= +github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d h1:ypNOsIwvdumNRlqWj/hsnLs5TyQWQOylwi+T9Qs454A= +github.com/anacrolix/bargle v0.0.0-20221014000746-4f2739072e9d/go.mod h1:9xUiZbkh+94FbiIAL1HXpAIBa832f3Mp07rRPl5c5RQ= github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8 h1:eyb0bBaQKMOh5Se/Qg54shijc8K4zpQiOjEhKFADkQM= github.com/anacrolix/chansync v0.4.1-0.20240627045151-1aa1ac392fe8/go.mod h1:DZsatdsdXxD0WiwcGl0nJVwyjCKMDv+knl1q2iBjA2k= github.com/anacrolix/dht/v2 v2.19.2-0.20221121215055-066ad8494444 h1:8V0K09lrGoeT2KRJNOtspA7q+OMxGwQqK/Ug0IiaaRE= diff --git a/mse/ctxrw.go b/internal/ctxrw/ctxrw.go similarity index 91% rename from mse/ctxrw.go rename to internal/ctxrw/ctxrw.go index 933c1871d6..8ca19310b9 100644 --- a/mse/ctxrw.go +++ b/internal/ctxrw/ctxrw.go @@ -1,4 +1,4 @@ -package mse +package ctxrw import ( "context" @@ -41,7 +41,7 @@ func (me contextedWriter) Write(p []byte) (n int, err error) { return contextedReadOrWrite(me.ctx, me.w.Write, p) } -func contextedReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter { +func WrapReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter { return struct { io.Reader io.Writer diff --git a/mse/mse.go b/mse/mse.go index 6ab6f2236f..f1bd63b8a3 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -12,6 +12,7 @@ import ( "errors" "expvar" "fmt" + "github.com/anacrolix/torrent/internal/ctxrw" "io" "math" "math/big" @@ -554,7 +555,7 @@ func InitiateHandshakeContext( ) { h := handshake{ conn: rw, - ctxConn: contextedReadWriter(ctx, rw), + ctxConn: ctxrw.WrapReadWriter(ctx, rw), initer: true, skey: skey, ia: initialPayload, @@ -589,7 +590,7 @@ func ReceiveHandshakeEx( ) (ret HandshakeResult) { h := handshake{ conn: rw, - ctxConn: contextedReadWriter(ctx, rw), + ctxConn: ctxrw.WrapReadWriter(ctx, rw), initer: false, skeys: skeys, chooseMethod: selectCrypto, diff --git a/peer_protocol/handshake.go b/peer_protocol/handshake.go index a6512c796e..a87b7c9215 100644 --- a/peer_protocol/handshake.go +++ b/peer_protocol/handshake.go @@ -1,9 +1,11 @@ package peer_protocol import ( + "context" "encoding/hex" "errors" "fmt" + "github.com/anacrolix/torrent/internal/ctxrw" "io" "math/bits" "strconv" @@ -122,10 +124,15 @@ type HandshakeResult struct { // connection. Returns ok if the Handshake was successful, and err if there was an unexpected // condition other than the peer simply abandoning the Handshake. func Handshake( - sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions PeerExtensionBits, + ctx context.Context, + sock io.ReadWriter, + ih *metainfo.Hash, + peerID [20]byte, + extensions PeerExtensionBits, ) ( res HandshakeResult, err error, ) { + sock = ctxrw.WrapReadWriter(ctx, sock) // Bytes to be sent to the peer. Should never block the sender. postCh := make(chan []byte, 4) // A single error value sent when the writer completes. diff --git a/torrent_test.go b/torrent_test.go index ab3f436103..bd409b6834 100644 --- a/torrent_test.go +++ b/torrent_test.go @@ -1,6 +1,7 @@ package torrent import ( + "context" "fmt" "io" "net" @@ -187,7 +188,7 @@ func TestTorrentMetainfoIncompleteMetadata(t *testing.T) { var pex PeerExtensionBits pex.SetBit(pp.ExtensionBitLtep, true) - hr, err := pp.Handshake(nc, &ih, [20]byte{}, pex) + hr, err := pp.Handshake(context.Background(), nc, &ih, [20]byte{}, pex) require.NoError(t, err) assert.True(t, hr.PeerExtensionBits.GetBit(pp.ExtensionBitLtep)) assert.EqualValues(t, cl.PeerID(), hr.PeerID)