diff --git a/peerconn.go b/peerconn.go index a8912f4a09..f12477fd6f 100644 --- a/peerconn.go +++ b/peerconn.go @@ -895,7 +895,9 @@ func (c *PeerConn) mainReadLoop() (err error) { err = c.onReadExtendedMsg(msg.ExtendedID, msg.ExtendedPayload) case pp.Hashes: err = c.onReadHashes(&msg) - case pp.HashRequest, pp.HashReject: + case pp.HashRequest: + err = c.onHashRequest(&msg) + case pp.HashReject: c.protocolLogger.Levelf(log.Info, "received unimplemented BitTorrent v2 message: %v", msg.Type) default: err = fmt.Errorf("received unknown message type: %#v", msg.Type) @@ -1356,6 +1358,69 @@ func (pc *PeerConn) onReadHashes(msg *pp.Message) (err error) { return nil } +func (pc *PeerConn) getHashes(msg *pp.Message) ([][32]byte, error) { + if msg.ProofLayers != 0 { + return nil, errors.New("proof layers not supported") + } + if msg.Length > 8192 { + return nil, fmt.Errorf("requested too many hashes: %d", msg.Length) + } + file := pc.t.getFileByPiecesRoot(msg.PiecesRoot) + if file == nil { + return nil, fmt.Errorf("no file for pieces root %x", msg.PiecesRoot) + } + beginPieceIndex := file.BeginPieceIndex() + endPieceIndex := file.EndPieceIndex() + length := merkle.RoundUpToPowerOfTwo(uint(endPieceIndex - beginPieceIndex)) + if uint(msg.Index+msg.Length) > length { + return nil, errors.New("invalid hash range") + } + + hashes := make([][32]byte, msg.Length) + padHash := metainfo.HashForPiecePad(int64(pc.t.usualPieceSize())) + for i := range hashes { + torrentPieceIndex := beginPieceIndex + int(msg.Index) + i + if torrentPieceIndex >= endPieceIndex { + hashes[i] = padHash + continue + } + piece := pc.t.piece(torrentPieceIndex) + hash, err := piece.obtainHashV2() + if err != nil { + return nil, fmt.Errorf("can't get hash for piece %d: %w", torrentPieceIndex, err) + } + hashes[i] = hash + } + return hashes, nil +} + +func (pc *PeerConn) onHashRequest(msg *pp.Message) error { + if !pc.t.info.HasV2() { + return errors.New("torrent has no v2 metadata") + } + + resp := pp.Message{ + PiecesRoot: msg.PiecesRoot, + BaseLayer: msg.BaseLayer, + Index: msg.Index, + Length: msg.Length, + ProofLayers: msg.ProofLayers, + } + + hashes, err := pc.getHashes(msg) + if err != nil { + pc.protocolLogger.WithNames(v2HashesLogName).Levelf(log.Debug, "error getting hashes: %v", err) + resp.Type = pp.HashReject + pc.write(resp) + return nil + } + + resp.Type = pp.Hashes + resp.Hashes = hashes + pc.write(resp) + return nil +} + type hashRequest struct { piecesRoot [32]byte baseLayer, index, length, proofLayers pp.Integer diff --git a/piece.go b/piece.go index 5b2a1f3b60..7dd3f61e30 100644 --- a/piece.go +++ b/piece.go @@ -1,6 +1,7 @@ package torrent import ( + "errors" "fmt" "sync" @@ -8,6 +9,7 @@ import ( g "github.com/anacrolix/generics" "github.com/anacrolix/missinggo/v2/bitmap" + "github.com/anacrolix/torrent/merkle" "github.com/anacrolix/torrent/metainfo" pp "github.com/anacrolix/torrent/peer_protocol" "github.com/anacrolix/torrent/storage" @@ -307,5 +309,28 @@ func (p *Piece) haveHash() bool { } func (p *Piece) hasPieceLayer() bool { - return int64(p.length()) > p.t.info.PieceLength + return len(p.files) == 1 && p.files[0].length > p.t.info.PieceLength +} + +func (p *Piece) obtainHashV2() (hash [32]byte, err error) { + if p.hashV2.Ok { + hash = p.hashV2.Value + return + } + if !p.hasPieceLayer() { + hash = p.mustGetOnlyFile().piecesRoot.Unwrap() + return + } + storage := p.Storage() + if !storage.Completion().Complete { + err = errors.New("piece incomplete") + return + } + + h := merkle.NewHash() + if _, err = storage.WriteTo(h); err != nil { + return + } + h.SumMinLength(hash[:0], int(p.t.info.PieceLength)) + return }