diff --git a/callbacks.go b/callbacks.go index f9ba131b13..1aeafbfc14 100644 --- a/callbacks.go +++ b/callbacks.go @@ -11,10 +11,14 @@ import ( type Callbacks struct { // Called after a peer connection completes the BitTorrent handshake. The Client lock is not // held. - CompletedHandshake func(*PeerConn, InfoHash) - ReadMessage func(*PeerConn, *pp.Message) + CompletedHandshake func(*PeerConn, InfoHash) + ReadMessage func(*PeerConn, *pp.Message) + // This can be folded into the general case below. ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage) PeerConnClosed func(*PeerConn) + // BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called + // in order. + PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent) // Provides secret keys to be tried against incoming encrypted connections. ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter @@ -38,3 +42,12 @@ type PeerRequestEvent struct { Peer *Peer Request } + +type PeerConnReadExtensionMessageEvent struct { + PeerConn *PeerConn + // Whether the client has builtin support for this extension. + BuiltinHandler bool + // You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap. + ExtensionNumber pp.ExtensionNumber + Payload []byte +} diff --git a/peer_protocol/extended.go b/peer_protocol/extended.go index 8bc5181633..019590e40a 100644 --- a/peer_protocol/extended.go +++ b/peer_protocol/extended.go @@ -24,7 +24,7 @@ type ( } ExtensionName string - ExtensionNumber int + ExtensionNumber uint8 ) const ( diff --git a/peerconn.go b/peerconn.go index e2d944ff26..03500be1ff 100644 --- a/peerconn.go +++ b/peerconn.go @@ -44,6 +44,8 @@ type PeerConn struct { PeerID PeerID PeerExtensionBytes pp.PeerExtensionBits PeerListenPort int + // 1-based mapping from extension number to extension name. + LocalLtepProtocolMap []pp.ExtensionName // The actual Conn, used for closing, and setting socket options. Do not use methods on this // while holding any mutexes. @@ -55,6 +57,7 @@ type PeerConn struct { messageWriter peerConnMsgWriter + // The peer's extension map, as sent in their extended handshake. PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber PeerClientName atomic.Value uploadTimer *time.Timer @@ -854,8 +857,20 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err }() t := c.t cl := t.cl - switch id { - case pp.HandshakeExtendedID: + { + event := PeerConnReadExtensionMessageEvent{ + PeerConn: c, + // Add one for the handshake ID. This isn't quite right yet, we don't actually have a + // way to differentiate builtin handlers. + BuiltinHandler: int(id) < len(c.LocalLtepProtocolMap)+1, + ExtensionNumber: id, + Payload: payload, + } + for _, cb := range c.callbacks.PeerConnReadExtensionMessage { + cb(event) + } + } + if id == pp.HandshakeExtendedID { var d pp.ExtendedHandshakeMessage if err := bencode.Unmarshal(payload, &d); err != nil { c.logger.Printf("error parsing extended handshake message %q: %s", payload, err) @@ -864,7 +879,6 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err if cb := c.callbacks.ReadExtendedHandshake; cb != nil { cb(c, &d) } - // c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d)) if d.Reqq != 0 { c.PeerMaxRequests = d.Reqq } @@ -896,13 +910,19 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err c.pex.Init(c) } return nil - case metadataExtendedId: + } + // Zero was taken care of above. + if int(id) >= len(c.LocalLtepProtocolMap)+1 { + return fmt.Errorf("unexpected extended message ID: %v", id) + } + switch c.LocalLtepProtocolMap[id-1] { + case pp.ExtensionNameMetadata: err := cl.gotMetadataExtensionMsg(payload, t, c) if err != nil { return fmt.Errorf("handling metadata extension message: %w", err) } return nil - case pexExtendedId: + case pp.ExtensionNamePex: if !c.pex.IsEnabled() { return nil // or hang-up maybe? } @@ -911,7 +931,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = fmt.Errorf("receiving pex message: %w", err) } return - case utHolepunchExtendedId: + case utHolepunch.ExtensionName: var msg utHolepunch.Msg err = msg.UnmarshalBinary(payload) if err != nil { @@ -921,7 +941,9 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = c.t.handleReceivedUtHolepunchMsg(msg, c) return default: - return fmt.Errorf("unexpected extended message ID: %v", id) + // This should have been a user configured protocol, and handled by a callback. How can they + // propagate errors? + return nil } }