diff --git a/dht.go b/dht.go
index df5a43148..7c89ab56a 100644
--- a/dht.go
+++ b/dht.go
@@ -1,9 +1,7 @@
 package dht

 import (
-	"bytes"
 	"context"
-	"errors"
 	"fmt"
 	"math"
 	"math/rand"
@@ -17,6 +15,7 @@ import (
 	"github.com/libp2p/go-libp2p-core/protocol"
 	"github.com/libp2p/go-libp2p-core/routing"

+	"github.com/libp2p/go-libp2p-kad-dht/internal"
 	"github.com/libp2p/go-libp2p-kad-dht/metrics"
 	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
 	"github.com/libp2p/go-libp2p-kad-dht/providers"
@@ -33,7 +32,6 @@ import (
 	goprocessctx "github.com/jbenet/goprocess/context"
 	"github.com/multiformats/go-base32"
 	ma "github.com/multiformats/go-multiaddr"
-	"github.com/multiformats/go-multihash"
 	"go.opencensus.io/tag"
 	"go.uber.org/zap"
 )
@@ -97,8 +95,8 @@ type IpfsDHT struct {
 	ctx  context.Context
 	proc goprocess.Process

-	strmap map[peer.ID]*messageSender
-	smlk   sync.Mutex
+	protoMessenger *pb.ProtocolMessenger
+	msgSender      *messageSenderImpl

 	plk sync.Mutex
@@ -190,6 +188,15 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error)
 	dht.disableFixLowPeers = cfg.disableFixLowPeers
 	dht.Validator = cfg.validator

+	dht.msgSender = &messageSenderImpl{
+		host:      h,
+		strmap:    make(map[peer.ID]*peerMessageSender),
+		protocols: dht.protocols,
+	}
+	dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender, pb.WithValidator(dht.Validator))
+	if err != nil {
+		return nil, err
+	}

 	dht.testAddressUpdateProcessing = cfg.testAddressUpdateProcessing
@@ -276,7 +283,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) {
 		selfKey:       kb.ConvertPeerID(h.ID()),
 		peerstore:     h.Peerstore(),
 		host:          h,
-		strmap:        make(map[peer.ID]*messageSender),
 		birth:         time.Now(),
 		protocols:     protocols,
 		protocolsStrs: protocol.ConvertToStrings(protocols),
@@ -530,80 +536,22 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() {
 	}
 }

-// putValueToPeer stores the given key/value pair at the peer 'p'
-func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, rec *recpb.Record) error {
-	pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0)
-	pmes.Record = rec
-	rpmes, err := dht.sendRequest(ctx, p, pmes)
-	if err != nil {
-		logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err)
-		return err
-	}
-
-	if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) {
-		logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes)
-		return errors.New("value not put correctly")
-	}
-
-	return nil
-}
-
-var errInvalidRecord = errors.New("received invalid record")
-
-// getValueOrPeers queries a particular peer p for the value for
-// key. It returns either the value or a list of closer peers.
-// NOTE: It will update the dht's peerstore with any new addresses
-// it finds for the given peer.
-func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
-	pmes, err := dht.getValueSingle(ctx, p, key)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// Perhaps we were given closer peers
-	peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())
-
-	if rec := pmes.GetRecord(); rec != nil {
-		// Success! We were given the value
-		logger.Debug("got value")
-
-		// make sure record is valid.
-		err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue())
-		if err != nil {
-			logger.Debug("received invalid record (discarded)")
-			// return a sentinal to signify an invalid record was received
-			err = errInvalidRecord
-			rec = new(recpb.Record)
-		}
-		return rec, peers, err
-	}
-
-	if len(peers) > 0 {
-		return nil, peers, nil
-	}
-
-	return nil, nil, routing.ErrNotFound
-}
-
-// getValueSingle simply performs the get value RPC with the given parameters
-func (dht *IpfsDHT) getValueSingle(ctx context.Context, p peer.ID, key string) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
-// getLocal attempts to retrieve the value from the datastore
+// getLocal attempts to retrieve the value from the datastore.
+//
+// returns nil, nil when either nothing is found or the value found doesn't properly validate.
+// returns nil, some_error when there's a *datastore* error (i.e., something goes very wrong)
 func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
-	logger.Debugw("finding value in datastore", "key", loggableRecordKeyString(key))
+	logger.Debugw("finding value in datastore", "key", internal.LoggableRecordKeyString(key))

 	rec, err := dht.getRecordFromDatastore(mkDsKey(key))
 	if err != nil {
-		logger.Warnw("get local failed", "key", loggableRecordKeyString(key), "error", err)
+		logger.Warnw("get local failed", "key", internal.LoggableRecordKeyString(key), "error", err)
 		return nil, err
 	}

 	// Double check the key. Can't hurt.
 	if rec != nil && string(rec.GetKey()) != key {
-		logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", loggableRecordKeyString(key), "got", rec.GetKey())
+		logger.Errorw("BUG: found a DHT record that didn't match its key", "expected", internal.LoggableRecordKeyString(key), "got", rec.GetKey())
 		return nil, nil
 	}
@@ -614,7 +562,7 @@ func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
 func (dht *IpfsDHT) putLocal(key string, rec *recpb.Record) error {
 	data, err := proto.Marshal(rec)
 	if err != nil {
-		logger.Warnw("failed to put marshal record for local put", "error", err, "key", loggableRecordKeyString(key))
+		logger.Warnw("failed to marshal record for local put", "error", err, "key", internal.LoggableRecordKeyString(key))
 		return err
 	}
@@ -719,17 +667,6 @@ func (dht *IpfsDHT) FindLocal(id peer.ID) peer.AddrInfo {
 	}
 }

-// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is
-func (dht *IpfsDHT) findPeerSingle(ctx context.Context, p peer.ID, id peer.ID) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
-func (dht *IpfsDHT) findProvidersSingle(ctx context.Context, p peer.ID, key multihash.Multihash) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
 // nearestPeersToQuery returns the routing tables closest peers.
 func (dht *IpfsDHT) nearestPeersToQuery(pmes *pb.Message, count int) []peer.ID {
 	closer := dht.routingTable.NearestPeers(kb.ConvertKey(string(pmes.GetKey())), count)
@@ -870,15 +807,7 @@ func (dht *IpfsDHT) Host() host.Host {

 // Ping sends a ping message to the passed peer and waits for a response.
 func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error {
-	req := pb.NewMessage(pb.Message_PING, nil, 0)
-	resp, err := dht.sendRequest(ctx, p, req)
-	if err != nil {
-		return fmt.Errorf("sending request: %w", err)
-	}
-	if resp.Type != pb.Message_PING {
-		return fmt.Errorf("got unexpected response type: %v", resp.Type)
-	}
-	return nil
+	return dht.protoMessenger.Ping(ctx, p)
 }

 // newContextWithLocalTags returns a new context.Context with the InstanceID and
diff --git a/dht_net.go b/dht_net.go
index 879be778f..278216625 100644
--- a/dht_net.go
+++ b/dht_net.go
@@ -2,14 +2,12 @@ package dht

 import (
 	"bufio"
-	"context"
 	"fmt"
 	"io"
 	"sync"
 	"time"

 	"github.com/libp2p/go-libp2p-core/network"
-	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-msgio/protoio"

 	"github.com/libp2p/go-libp2p-kad-dht/metrics"
@@ -207,281 +205,3 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
 		stats.Record(ctx, metrics.InboundRequestLatency.M(latencyMillis))
 	}
 }
-
-// sendRequest sends out a request, but also makes sure to
-// measure the RTT for latency measurements.
-func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
-	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
-
-	ms, err := dht.messageSenderForPeer(ctx, p)
-	if err != nil {
-		stats.Record(ctx,
-			metrics.SentRequests.M(1),
-			metrics.SentRequestErrors.M(1),
-		)
-		logger.Debugw("request failed to open message sender", "error", err, "to", p)
-		return nil, err
-	}
-
-	start := time.Now()
-
-	rpmes, err := ms.SendRequest(ctx, pmes)
-	if err != nil {
-		stats.Record(ctx,
-			metrics.SentRequests.M(1),
-			metrics.SentRequestErrors.M(1),
-		)
-		logger.Debugw("request failed", "error", err, "to", p)
-		return nil, err
-	}
-
-	stats.Record(ctx,
-		metrics.SentRequests.M(1),
-		metrics.SentBytes.M(int64(pmes.Size())),
-		metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
-	)
-	dht.peerstore.RecordLatency(p, time.Since(start))
-	return rpmes, nil
-}
-
-// sendMessage sends out a message
-func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
-	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
-
-	ms, err := dht.messageSenderForPeer(ctx, p)
-	if err != nil {
-		stats.Record(ctx,
-			metrics.SentMessages.M(1),
-			metrics.SentMessageErrors.M(1),
-		)
-		logger.Debugw("message failed to open message sender", "error", err, "to", p)
-		return err
-	}
-
-	if err := ms.SendMessage(ctx, pmes); err != nil {
-		stats.Record(ctx,
-			metrics.SentMessages.M(1),
-			metrics.SentMessageErrors.M(1),
-		)
-		logger.Debugw("message failed", "error", err, "to", p)
-		return err
-	}
-
-	stats.Record(ctx,
-		metrics.SentMessages.M(1),
-		metrics.SentBytes.M(int64(pmes.Size())),
-	)
-	return nil
-}
-
-func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
-	dht.smlk.Lock()
-	ms, ok := dht.strmap[p]
-	if ok {
-		dht.smlk.Unlock()
-		return ms, nil
-	}
-	ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
-	dht.strmap[p] = ms
-	dht.smlk.Unlock()
-
-	if err := ms.prepOrInvalidate(ctx); err != nil {
-		dht.smlk.Lock()
-		defer dht.smlk.Unlock()
-
-		if msCur, ok := dht.strmap[p]; ok {
-			// Changed. Use the new one, old one is invalid and
-			// not in the map so we can just throw it away.
-			if ms != msCur {
-				return msCur, nil
-			}
-			// Not changed, remove the now invalid stream from the
-			// map.
-			delete(dht.strmap, p)
-		}
-		// Invalid but not in map. Must have been removed by a disconnect.
-		return nil, err
-	}
-	// All ready to go.
-	return ms, nil
-}
-
-type messageSender struct {
-	s   network.Stream
-	r   msgio.ReadCloser
-	lk  ctxMutex
-	p   peer.ID
-	dht *IpfsDHT
-
-	invalid   bool
-	singleMes int
-}
-
-// invalidate is called before this messageSender is removed from the strmap.
-// It prevents the messageSender from being reused/reinitialized and then
-// forgotten (leaving the stream open).
-func (ms *messageSender) invalidate() {
-	ms.invalid = true
-	if ms.s != nil {
-		_ = ms.s.Reset()
-		ms.s = nil
-	}
-}
-
-func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
-	if err := ms.lk.Lock(ctx); err != nil {
-		return err
-	}
-	defer ms.lk.Unlock()
-
-	if err := ms.prep(ctx); err != nil {
-		ms.invalidate()
-		return err
-	}
-	return nil
-}
-
-func (ms *messageSender) prep(ctx context.Context) error {
-	if ms.invalid {
-		return fmt.Errorf("message sender has been invalidated")
-	}
-	if ms.s != nil {
-		return nil
-	}
-
-	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
-	// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
-	// backwards compatibility reasons).
-	nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
-	if err != nil {
-		return err
-	}
-
-	ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax)
-	ms.s = nstr
-
-	return nil
-}
-
-// streamReuseTries is the number of times we will try to reuse a stream to a
-// given peer before giving up and reverting to the old one-message-per-stream
-// behaviour.
-const streamReuseTries = 3
-
-func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
-	if err := ms.lk.Lock(ctx); err != nil {
-		return err
-	}
-	defer ms.lk.Unlock()
-
-	retry := false
-	for {
-		if err := ms.prep(ctx); err != nil {
-			return err
-		}
-
-		if err := ms.writeMsg(pmes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error writing message", "error", err)
-				return err
-			}
-			logger.Debugw("error writing message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
-
-		var err error
-		if ms.singleMes > streamReuseTries {
-			err = ms.s.Close()
-			ms.s = nil
-		} else if retry {
-			ms.singleMes++
-		}
-
-		return err
-	}
-}
-
-func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
-	if err := ms.lk.Lock(ctx); err != nil {
-		return nil, err
-	}
-	defer ms.lk.Unlock()
-
-	retry := false
-	for {
-		if err := ms.prep(ctx); err != nil {
-			return nil, err
-		}
-
-		if err := ms.writeMsg(pmes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error writing message", "error", err)
-				return nil, err
-			}
-			logger.Debugw("error writing message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
-
-		mes := new(pb.Message)
-		if err := ms.ctxReadMsg(ctx, mes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error reading message", "error", err)
-				return nil, err
-			}
-			logger.Debugw("error reading message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
-
-		var err error
-		if ms.singleMes > streamReuseTries {
-			err = ms.s.Close()
-			ms.s = nil
-		} else if retry {
-			ms.singleMes++
-		}
-
-		return mes, err
-	}
-}
-
-func (ms *messageSender) writeMsg(pmes *pb.Message) error {
-	return writeMsg(ms.s, pmes)
-}
-
-func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
-	errc := make(chan error, 1)
-	go func(r msgio.ReadCloser) {
-		defer close(errc)
-		bytes, err := r.ReadMsg()
-		defer r.ReleaseMsg(bytes)
-		if err != nil {
-			errc <- err
-			return
-		}
-		errc <- mes.Unmarshal(bytes)
-	}(ms.r)
-
-	t := time.NewTimer(dhtReadMessageTimeout)
-	defer t.Stop()
-
-	select {
-	case err := <-errc:
-		return err
-	case <-ctx.Done():
-		return ctx.Err()
-	case <-t.C:
-		return ErrReadTimeout
-	}
-}
diff --git a/dht_test.go b/dht_test.go
index 85dcc22f6..74835c8e2 100644
--- a/dht_test.go
+++ b/dht_test.go
@@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) {
 	defer dht.Close()

 	foo := peer.ID("asdasd")
-	_, err := dht.messageSenderForPeer(ctx, foo)
+	_, err := dht.msgSender.messageSenderForPeer(ctx, foo)
 	if err == nil {
 		t.Fatal("that shouldnt have succeeded")
 	}

-	dht.smlk.Lock()
-	mscnt := len(dht.strmap)
-	dht.smlk.Unlock()
+	dht.msgSender.smlk.Lock()
+	mscnt := len(dht.msgSender.strmap)
+	dht.msgSender.smlk.Unlock()

 	if mscnt > 0 {
 		t.Fatal("should have no message senders in map")
diff --git a/go.sum b/go.sum
index da973767b..83fe301f2 100644
--- a/go.sum
+++ b/go.sum
@@ -62,6 +62,7 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4er
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
@@ -73,6 +74,7 @@ github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ
 github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
@@ -100,6 +102,7 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
 github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
 github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
 github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
+github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
 github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo=
 github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc=
@@ -488,6 +491,7 @@ github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI
 github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
 github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
 github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU=
 github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg=
 github.com/onsi/ginkgo v1.12.1 h1:mFwc4LvZ0xpSvDZ3E+k8Yte0hLOMxXUlP+yXtJqkYfQ=
 github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
@@ -652,6 +656,7 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
 golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -682,6 +687,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
 gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
 gopkg.in/src-d/go-cli.v0 v0.0.0-20181105080154-d492247bbc0d/go.mod h1:z+K8VcOYVYcSwSjGebuDL6176A1XskgbtNl64NSg+n8=
 gopkg.in/src-d/go-log.v1 v1.0.1/go.mod h1:GN34hKP0g305ysm2/hctJ0Y8nWP3zxXXJ8GFabTyABE=
diff --git a/handlers.go b/handlers.go
index 99bea1942..5160232c0 100644
--- a/handlers.go
+++ b/handlers.go
@@ -14,6 +14,7 @@ import (
 	"github.com/gogo/protobuf/proto"
 	ds "github.com/ipfs/go-datastore"
 	u "github.com/ipfs/go-ipfs-util"
+	"github.com/libp2p/go-libp2p-kad-dht/internal"
 	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
 	recpb "github.com/libp2p/go-libp2p-record/pb"
 	"github.com/multiformats/go-base32"
@@ -167,7 +168,7 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
 	// Make sure the record is valid (not expired, valid signature etc)
 	if err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue()); err != nil {
-		logger.Infow("bad dht record in PUT", "from", p, "key", loggableRecordKeyBytes(rec.GetKey()), "error", err)
+		logger.Infow("bad dht record in PUT", "from", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey()), "error", err)
 		return nil, err
 	}
@@ -196,11 +197,11 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
 		recs := [][]byte{rec.GetValue(), existing.GetValue()}
 		i, err := dht.Validator.Select(string(rec.GetKey()), recs)
 		if err != nil {
-			logger.Warnw("dht record passed validation but failed select", "from", p, "key", loggableRecordKeyBytes(rec.GetKey()), "error", err)
+			logger.Warnw("dht record passed validation but failed select", "from", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey()), "error", err)
 			return nil, err
 		}
 		if i != 0 {
-			logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", loggableRecordKeyBytes(rec.GetKey()))
+			logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey()))
 			return nil, errors.New("old record")
 		}
 	}
@@ -344,7 +345,7 @@ func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.M
 		return nil, fmt.Errorf("handleAddProvider key is empty")
 	}

-	logger.Debugf("adding provider", "from", p, "key", loggableProviderRecordBytes(key))
+	logger.Debugw("adding provider", "from", p, "key", internal.LoggableProviderRecordBytes(key))

 	// add provider should use the address given in the message
 	pinfos := pb.PBPeersToPeerInfos(pmes.GetProviderPeers())
diff --git a/internal/errors.go b/internal/errors.go
new file mode 100644
index 000000000..3c32a83dc
--- /dev/null
+++ b/internal/errors.go
@@ -0,0 +1,5 @@
+package internal
+
+import "errors"
+
+var ErrInvalidRecord = errors.New("received invalid record")
diff --git a/logging.go b/internal/logging.go
similarity index 72%
rename from logging.go
rename to internal/logging.go
index ffc337e3c..981f728cd 100644
--- a/logging.go
+++ b/internal/logging.go
@@ -1,4 +1,4 @@
-package dht
+package internal

 import (
 	"fmt"
@@ -20,14 +20,14 @@ func multibaseB32Encode(k []byte) string {
 func tryFormatLoggableRecordKey(k string) (string, error) {
 	if len(k) == 0 {
-		return "", fmt.Errorf("loggableRecordKey is empty")
+		return "", fmt.Errorf("LoggableRecordKey is empty")
 	}
 	var proto, cstr string
 	if k[0] == '/' {
 		// it's a path (probably)
 		protoEnd := strings.IndexByte(k[1:], '/')
 		if protoEnd < 0 {
-			return "", fmt.Errorf("loggableRecordKey starts with '/' but is not a path: %s", multibaseB32Encode([]byte(k)))
+			return "", fmt.Errorf("LoggableRecordKey starts with '/' but is not a path: %s", multibaseB32Encode([]byte(k)))
 		}
 		proto = k[1 : protoEnd+1]
 		cstr = k[protoEnd+2:]
@@ -36,12 +36,12 @@ func tryFormatLoggableRecordKey(k string) (string, error) {
 		return fmt.Sprintf("/%s/%s", proto, encStr), nil
 	}

-	return "", fmt.Errorf("loggableRecordKey is not a path: %s", multibaseB32Encode([]byte(cstr)))
+	return "", fmt.Errorf("LoggableRecordKey is not a path: %s", multibaseB32Encode([]byte(cstr)))
 }

-type loggableRecordKeyString string
+type LoggableRecordKeyString string

-func (lk loggableRecordKeyString) String() string {
+func (lk LoggableRecordKeyString) String() string {
 	k := string(lk)
 	newKey, err := tryFormatLoggableRecordKey(k)
 	if err == nil {
@@ -50,9 +50,9 @@ func (lk loggableRecordKeyString) String() string {
 	return err.Error()
 }

-type loggableRecordKeyBytes []byte
+type LoggableRecordKeyBytes []byte

-func (lk loggableRecordKeyBytes) String() string {
+func (lk LoggableRecordKeyBytes) String() string {
 	k := string(lk)
 	newKey, err := tryFormatLoggableRecordKey(k)
 	if err == nil {
@@ -61,9 +61,9 @@
 	return err.Error()
 }

-type loggableProviderRecordBytes []byte
+type LoggableProviderRecordBytes []byte

-func (lk loggableProviderRecordBytes) String() string {
+func (lk LoggableProviderRecordBytes) String() string {
 	newKey, err := tryFormatLoggableProviderKey(lk)
 	if err == nil {
 		return newKey
@@ -73,7 +73,7 @@
 func tryFormatLoggableProviderKey(k []byte) (string, error) {
 	if len(k) == 0 {
-		return "", fmt.Errorf("loggableProviderKey is empty")
+		return "", fmt.Errorf("LoggableProviderKey is empty")
 	}

 	encodedKey := multibaseB32Encode(k)
@@ -88,5 +88,5 @@
 		return encodedKey, nil
 	}

-	return "", fmt.Errorf("loggableProviderKey is not a Multihash or CID: %s", encodedKey)
+	return "", fmt.Errorf("LoggableProviderKey is not a Multihash or CID: %s", encodedKey)
 }
diff --git a/logging_test.go b/internal/logging_test.go
similarity index 99%
rename from logging_test.go
rename to internal/logging_test.go
index 64e4f87ba..25bd033e8 100644
--- a/logging_test.go
+++ b/internal/logging_test.go
@@ -1,4 +1,4 @@
-package dht
+package internal

 import (
 	"testing"
diff --git a/lookup.go b/lookup.go
index 0168601a5..dff8bb244 100644
--- a/lookup.go
+++ b/lookup.go
@@ -8,7 +8,6 @@ import (
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-libp2p-core/routing"

-	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
 	kb "github.com/libp2p/go-libp2p-kbucket"
 )
@@ -30,12 +29,11 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan pee
 			ID:   p,
 		})

-		pmes, err := dht.findPeerSingle(ctx, p, peer.ID(key))
+		peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, peer.ID(key))
 		if err != nil {
 			logger.Debugf("error getting closer peers: %s", err)
 			return nil, err
 		}
-		peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

 		// For DHT query command
 		routing.PublishQueryEvent(ctx, &routing.QueryEvent{
diff --git a/message_manager.go b/message_manager.go
new file mode 100644
index 000000000..8cc3e22e3
--- /dev/null
+++ b/message_manager.go
@@ -0,0 +1,327 @@
+package dht
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/libp2p/go-libp2p-core/host"
+	"github.com/libp2p/go-libp2p-core/network"
+	"github.com/libp2p/go-libp2p-core/peer"
+	"github.com/libp2p/go-libp2p-core/protocol"
+
+	"github.com/libp2p/go-libp2p-kad-dht/metrics"
+	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
+
+	"github.com/libp2p/go-msgio"
+	"go.opencensus.io/stats"
+	"go.opencensus.io/tag"
+)
+
+// messageSenderImpl is responsible for sending requests and messages to peers efficiently, including reuse of streams.
+// It also tracks metrics for sent requests and messages.
+type messageSenderImpl struct {
+	host      host.Host // the network services we need
+	smlk      sync.Mutex
+	strmap    map[peer.ID]*peerMessageSender
+	protocols []protocol.ID
+}
+
+func (m *messageSenderImpl) streamDisconnect(ctx context.Context, p peer.ID) {
+	m.smlk.Lock()
+	defer m.smlk.Unlock()
+	ms, ok := m.strmap[p]
+	if !ok {
+		return
+	}
+	delete(m.strmap, p)
+
+	// Do this asynchronously as ms.lk can block for a while.
+	go func() {
+		if err := ms.lk.Lock(ctx); err != nil {
+			return
+		}
+		defer ms.lk.Unlock()
+		ms.invalidate()
+	}()
+}
+
+// SendRequest sends out a request, but also makes sure to
+// measure the RTT for latency measurements.
+func (m *messageSenderImpl) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
+	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
+
+	ms, err := m.messageSenderForPeer(ctx, p)
+	if err != nil {
+		stats.Record(ctx,
+			metrics.SentRequests.M(1),
+			metrics.SentRequestErrors.M(1),
+		)
+		logger.Debugw("request failed to open message sender", "error", err, "to", p)
+		return nil, err
+	}
+
+	start := time.Now()
+
+	rpmes, err := ms.SendRequest(ctx, pmes)
+	if err != nil {
+		stats.Record(ctx,
+			metrics.SentRequests.M(1),
+			metrics.SentRequestErrors.M(1),
+		)
+		logger.Debugw("request failed", "error", err, "to", p)
+		return nil, err
+	}
+
+	stats.Record(ctx,
+		metrics.SentRequests.M(1),
+		metrics.SentBytes.M(int64(pmes.Size())),
+		metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
+	)
+	m.host.Peerstore().RecordLatency(p, time.Since(start))
+	return rpmes, nil
+}
+
+// SendMessage sends out a message
+func (m *messageSenderImpl) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
+	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
+
+	ms, err := m.messageSenderForPeer(ctx, p)
+	if err != nil {
+		stats.Record(ctx,
+			metrics.SentMessages.M(1),
+			metrics.SentMessageErrors.M(1),
+		)
+		logger.Debugw("message failed to open message sender", "error", err, "to", p)
+		return err
+	}
+
+	if err := ms.SendMessage(ctx, pmes); err != nil {
+		stats.Record(ctx,
+			metrics.SentMessages.M(1),
+			metrics.SentMessageErrors.M(1),
+		)
+		logger.Debugw("message failed", "error", err, "to", p)
+		return err
+	}
+
+	stats.Record(ctx,
+		metrics.SentMessages.M(1),
+		metrics.SentBytes.M(int64(pmes.Size())),
+	)
+	return nil
+}
+
+func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID) (*peerMessageSender, error) {
+	m.smlk.Lock()
+	ms, ok := m.strmap[p]
+	if ok {
+		m.smlk.Unlock()
+		return ms, nil
+	}
+	ms = &peerMessageSender{p: p, m: m, lk: newCtxMutex()}
+	m.strmap[p] = ms
+	m.smlk.Unlock()
+
+	if err := ms.prepOrInvalidate(ctx); err != nil {
+		m.smlk.Lock()
+		defer m.smlk.Unlock()
+
+		if msCur, ok := m.strmap[p]; ok {
+			// Changed. Use the new one, old one is invalid and
+			// not in the map so we can just throw it away.
+			if ms != msCur {
+				return msCur, nil
+			}
+			// Not changed, remove the now invalid stream from the
+			// map.
+			delete(m.strmap, p)
+		}
+		// Invalid but not in map. Must have been removed by a disconnect.
+		return nil, err
+	}
+	// All ready to go.
+	return ms, nil
+}
+
+// peerMessageSender is responsible for sending requests and messages to a particular peer
+type peerMessageSender struct {
+	s  network.Stream
+	r  msgio.ReadCloser
+	lk ctxMutex
+	p  peer.ID
+	m  *messageSenderImpl
+
+	invalid   bool
+	singleMes int
+}
+
+// invalidate is called before this peerMessageSender is removed from the strmap.
+// It prevents the peerMessageSender from being reused/reinitialized and then
+// forgotten (leaving the stream open).
+func (ms *peerMessageSender) invalidate() {
+	ms.invalid = true
+	if ms.s != nil {
+		_ = ms.s.Reset()
+		ms.s = nil
+	}
+}
+
+func (ms *peerMessageSender) prepOrInvalidate(ctx context.Context) error {
+	if err := ms.lk.Lock(ctx); err != nil {
+		return err
+	}
+	defer ms.lk.Unlock()
+
+	if err := ms.prep(ctx); err != nil {
+		ms.invalidate()
+		return err
+	}
+	return nil
+}
+
+func (ms *peerMessageSender) prep(ctx context.Context) error {
+	if ms.invalid {
+		return fmt.Errorf("message sender has been invalidated")
+	}
+	if ms.s != nil {
+		return nil
+	}
+
+	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
+	// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
+	// backwards compatibility reasons).
+	nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...)
+	if err != nil {
+		return err
+	}
+
+	ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax)
+	ms.s = nstr
+
+	return nil
+}
+
+// streamReuseTries is the number of times we will try to reuse a stream to a
+// given peer before giving up and reverting to the old one-message-per-stream
+// behaviour.
+const streamReuseTries = 3
+
+func (ms *peerMessageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
+	if err := ms.lk.Lock(ctx); err != nil {
+		return err
+	}
+	defer ms.lk.Unlock()
+
+	retry := false
+	for {
+		if err := ms.prep(ctx); err != nil {
+			return err
+		}
+
+		if err := ms.writeMsg(pmes); err != nil {
+			_ = ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				logger.Debugw("error writing message", "error", err)
+				return err
+			}
+			logger.Debugw("error writing message", "error", err, "retrying", true)
+			retry = true
+			continue
+		}
+
+		var err error
+		if ms.singleMes > streamReuseTries {
+			err = ms.s.Close()
+			ms.s = nil
+		} else if retry {
+			ms.singleMes++
+		}
+
+		return err
+	}
+}
+
+func (ms *peerMessageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
+	if err := ms.lk.Lock(ctx); err != nil {
+		return nil, err
+	}
+	defer ms.lk.Unlock()
+
+	retry := false
+	for {
+		if err := ms.prep(ctx); err != nil {
+			return nil, err
+		}
+
+		if err := ms.writeMsg(pmes); err != nil {
+			_ = ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				logger.Debugw("error writing message", "error", err)
+				return nil, err
+			}
+			logger.Debugw("error writing message", "error", err, "retrying", true)
+			retry = true
+			continue
+		}
+
+		mes := new(pb.Message)
+		if err := ms.ctxReadMsg(ctx, mes); err != nil {
+			_ = ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				logger.Debugw("error reading message", "error", err)
+				return nil, err
+			}
+			logger.Debugw("error reading message", "error", err, "retrying", true)
+			retry = true
+			continue
+		}
+
+		var err error
+		if ms.singleMes > streamReuseTries {
+			err = ms.s.Close()
+			ms.s = nil
+		} else if retry {
+			ms.singleMes++
+		}
+
+		return mes, err
+	}
+}
+
+func (ms *peerMessageSender) writeMsg(pmes *pb.Message) error {
+	return writeMsg(ms.s, pmes)
+}
+
+func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
+	errc := make(chan error, 1)
+	go func(r msgio.ReadCloser) {
+		defer close(errc)
+		bytes, err := r.ReadMsg()
+		defer r.ReleaseMsg(bytes)
+		if err != nil {
+			errc <- err
+			return
+		}
+		errc <- mes.Unmarshal(bytes)
+	}(ms.r)
+
+	t := time.NewTimer(dhtReadMessageTimeout)
+	defer t.Stop()
+
+	select {
+	case err := <-errc:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-t.C:
+		return ErrReadTimeout
+	}
+}
diff --git a/pb/protocol_messenger.go b/pb/protocol_messenger.go
new file mode 100644
index 000000000..7524f59b9
--- /dev/null
+++ b/pb/protocol_messenger.go
@@ -0,0 +1,175 @@
+package dht_pb
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/libp2p/go-libp2p-core/host"
+	"github.com/libp2p/go-libp2p-core/peer"
+	"github.com/libp2p/go-libp2p-core/routing"
+
+	logging "github.com/ipfs/go-log"
+	record "github.com/libp2p/go-libp2p-record"
+	recpb "github.com/libp2p/go-libp2p-record/pb"
+	"github.com/multiformats/go-multihash"
+
+	"github.com/libp2p/go-libp2p-kad-dht/internal"
+)
+
+var logger = logging.Logger("dht")
+
+// ProtocolMessenger can be used for sending DHT messages to peers and processing their responses.
+// This decouples the wire protocol format from both the DHT protocol implementation and the implementation of the
+// routing.Routing interface.
+//
+// Note: the ProtocolMessenger's MessageSender still needs to deal with some wire protocol details such as using
+// varint-delimited protobufs
+type ProtocolMessenger struct {
+	m         MessageSender
+	validator record.Validator
+}
+
+type ProtocolMessengerOption func(*ProtocolMessenger) error
+
+func WithValidator(validator record.Validator) ProtocolMessengerOption {
+	return func(messenger *ProtocolMessenger) error {
+		messenger.validator = validator
+		return nil
+	}
+}
+
+// NewProtocolMessenger creates a new ProtocolMessenger that is used for sending DHT messages to peers and processing
+// their responses.
+func NewProtocolMessenger(msgSender MessageSender, opts ...ProtocolMessengerOption) (*ProtocolMessenger, error) {
+	pm := &ProtocolMessenger{
+		m: msgSender,
+	}
+
+	for _, o := range opts {
+		if err := o(pm); err != nil {
+			return nil, err
+		}
+	}
+
+	return pm, nil
+}
+
+// MessageSender handles sending wire protocol messages to a given peer
+type MessageSender interface {
+	// SendRequest sends a peer a message and waits for its response
+	SendRequest(ctx context.Context, p peer.ID, pmes *Message) (*Message, error)
+	// SendMessage sends a peer a message without waiting on a response
+	SendMessage(ctx context.Context, p peer.ID, pmes *Message) error
+}
+
+// PutValue asks a peer to store the given key/value pair.
+func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb.Record) error {
+	pmes := NewMessage(Message_PUT_VALUE, rec.Key, 0)
+	pmes.Record = rec
+	rpmes, err := pm.m.SendRequest(ctx, p, pmes)
+	if err != nil {
+		logger.Debugw("failed to put value to peer", "to", p, "key", internal.LoggableRecordKeyBytes(rec.Key), "error", err)
+		return err
+	}
+
+	if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) {
+		const errStr = "value not put correctly"
+		logger.Infow(errStr, "put-message", pmes, "get-message", rpmes)
+		return errors.New(errStr)
+	}
+
+	return nil
+}
+
+// GetValue asks a peer for the value corresponding to the given key. Also returns the K closest peers to the key
+// as described in GetClosestPeers.
+func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
+	pmes := NewMessage(Message_GET_VALUE, []byte(key), 0)
+	respMsg, err := pm.m.SendRequest(ctx, p, pmes)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// Perhaps we were given closer peers
+	peers := PBPeersToPeerInfos(respMsg.GetCloserPeers())
+
+	if rec := respMsg.GetRecord(); rec != nil {
+		// Success! We were given the value
+		logger.Debug("got value")
+
+		// make sure record is valid.
+		err = pm.validator.Validate(string(rec.GetKey()), rec.GetValue())
+		if err != nil {
+			logger.Debug("received invalid record (discarded)")
+			// return a sentinel to signify an invalid record was received
+			return nil, peers, internal.ErrInvalidRecord
+		}
+		return rec, peers, err
+	}
+
+	if len(peers) > 0 {
+		return nil, peers, nil
+	}
+
+	return nil, nil, routing.ErrNotFound
+}
+
+// GetClosestPeers asks a peer to return the K (a DHT-wide parameter) DHT server peers closest in XOR space to the id
+// Note: If the peer happens to know another peer whose peerID exactly matches the given id it will return that peer
+// even if that peer is not a DHT server node.
+func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) {
+	pmes := NewMessage(Message_FIND_NODE, []byte(id), 0)
+	respMsg, err := pm.m.SendRequest(ctx, p, pmes)
+	if err != nil {
+		return nil, err
+	}
+	peers := PBPeersToPeerInfos(respMsg.GetCloserPeers())
+	return peers, nil
+}
+
+// PutProvider asks a peer to store that we are a provider for the given key.
+func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key multihash.Multihash, host host.Host) error {
+	pi := peer.AddrInfo{
+		ID:    host.ID(),
+		Addrs: host.Addrs(),
+	}
+
+	// TODO: We may want to limit the type of addresses in our provider records
+	// For example, in a WAN-only DHT prohibit sharing non-WAN addresses (e.g. 192.168.0.100)
+	if len(pi.Addrs) < 1 {
+		return fmt.Errorf("no known addresses for self, cannot put provider")
+	}
+
+	pmes := NewMessage(Message_ADD_PROVIDER, key, 0)
+	pmes.ProviderPeers = RawPeerInfosToPBPeers([]peer.AddrInfo{pi})
+
+	return pm.m.SendMessage(ctx, p, pmes)
+}
+
+// GetProviders asks a peer for the providers it knows of for a given key. Also returns the K closest peers to the key
+// as described in GetClosestPeers.
+func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key multihash.Multihash) ([]*peer.AddrInfo, []*peer.AddrInfo, error) {
+	pmes := NewMessage(Message_GET_PROVIDERS, key, 0)
+	respMsg, err := pm.m.SendRequest(ctx, p, pmes)
+	if err != nil {
+		return nil, nil, err
+	}
+	provs := PBPeersToPeerInfos(respMsg.GetProviderPeers())
+	closerPeers := PBPeersToPeerInfos(respMsg.GetCloserPeers())
+	return provs, closerPeers, nil
+}
+
+// Ping sends a ping message to the passed peer and waits for a response.
+func (pm *ProtocolMessenger) Ping(ctx context.Context, p peer.ID) error {
+	req := NewMessage(Message_PING, nil, 0)
+	resp, err := pm.m.SendRequest(ctx, p, req)
+	if err != nil {
+		return fmt.Errorf("sending request: %w", err)
+	}
+	if resp.Type != Message_PING {
+		return fmt.Errorf("got unexpected response type: %v", resp.Type)
+	}
+	return nil
+}
diff --git a/records.go b/records.go
index adb28ce7d..bba505080 100644
--- a/records.go
+++ b/records.go
@@ -98,13 +98,12 @@ func (dht *IpfsDHT) getPublicKeyFromNode(ctx context.Context, p peer.ID) (ci.Pub
 	// Get the key from the node itself
 	pkkey := routing.KeyForPublicKey(p)

-	pmes, err := dht.getValueSingle(ctx, p, pkkey)
+	record, _, err := dht.protoMessenger.GetValue(ctx, p, pkkey)
 	if err != nil {
 		return nil, err
 	}

 	// node doesn't have key :(
-	record := pmes.GetRecord()
 	if record == nil {
 		return nil, fmt.Errorf("node %v not responding with its public key", p)
 	}
diff --git a/routing.go b/routing.go
index 4d6ae990c..d14e3845f 100644
--- a/routing.go
+++ b/routing.go
@@ -14,7 +14,7 @@ import (
 	"github.com/ipfs/go-cid"
 	u "github.com/ipfs/go-ipfs-util"
-	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
+	"github.com/libp2p/go-libp2p-kad-dht/internal"
 	"github.com/libp2p/go-libp2p-kad-dht/qpeerset"
 	kb "github.com/libp2p/go-libp2p-kbucket"
 	record "github.com/libp2p/go-libp2p-record"
@@ -32,7 +32,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts
 		return routing.ErrNotSupported
 	}

-	logger.Debugw("putting value", "key", loggableRecordKeyString(key))
+	logger.Debugw("putting value", "key", internal.LoggableRecordKeyString(key))

 	// don't even allow local users to put bad values.
 	if err := dht.Validator.Validate(key, value); err != nil {
@@ -81,7 +81,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts
 			ID:   p,
 		})

-		err := dht.putValueToPeer(ctx, p, rec)
+		err := dht.protoMessenger.PutValue(ctx, p, rec)
 		if err != nil {
 			logger.Debugf("failed putting value to peer: %s", err)
 		}
@@ -128,7 +128,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op
 	if best == nil {
 		return nil, routing.ErrNotFound
 	}
-	logger.Debugf("GetValue %v %x", loggableRecordKeyString(key), best)
+	logger.Debugf("GetValue %v %x", internal.LoggableRecordKeyString(key), best)
 	return best, nil
 }
@@ -247,7 +247,7 @@ loop:
 			}
 			sel, err := dht.Validator.Select(key, [][]byte{best, v.Val})
 			if err != nil {
-				logger.Warnw("failed to select best value", "key", loggableRecordKeyString(key), "error", err)
+				logger.Warnw("failed to select best value", "key", internal.LoggableRecordKeyString(key), "error", err)
 				continue
 			}
 			if sel != 1 {
@@ -281,7 +281,7 @@ func (dht *IpfsDHT) updatePeerValues(ctx context.Context, key string, val []byte
 			}
 			ctx, cancel := context.WithTimeout(ctx, time.Second*30)
 			defer cancel()
-			err := dht.putValueToPeer(ctx, p, fixupRec)
+			err := dht.protoMessenger.PutValue(ctx, p, fixupRec)
 			if err != nil {
 				logger.Debug("Error correcting DHT entry: ", err)
 			}
@@ -293,7 +293,7 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st
 	valCh := make(chan RecvdVal, 1)
 	lookupResCh := make(chan *lookupWithFollowupResult, 1)

-	logger.Debugw("finding value", "key", loggableRecordKeyString(key))
+	logger.Debugw("finding value", "key", internal.LoggableRecordKeyString(key))

 	if rec, err := dht.getLocal(key); rec != nil && err == nil {
 		select {
@@ -316,7 +316,7 @@
 			ID:   p,
 		})

-		rec, peers, err := dht.getValueOrPeers(ctx, p, key)
+		rec, peers, err := dht.protoMessenger.GetValue(ctx, p, key)
 		switch err {
 		case routing.ErrNotFound:
 			// in this case, they responded with nothing,
@@ -329,7 +329,7 @@
 			return nil, err
 		default:
 			return nil, err
-		case nil, errInvalidRecord:
+		case nil, internal.ErrInvalidRecord:
 			// in either of these cases, we want to keep going
 		}
@@ -399,7 +399,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err
 		return fmt.Errorf("invalid cid: undefined")
 	}
 	keyMH := key.Hash()
-	logger.Debugw("providing", "cid", key, "mh", loggableProviderRecordBytes(keyMH))
+	logger.Debugw("providing", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH))

 	// add self locally
 	dht.ProviderManager.AddProvider(ctx, keyMH, dht.self)
@@ -444,18 +444,13 @@
 		return err
 	}

-	mes, err := dht.makeProvRecord(keyMH)
-	if err != nil {
-		return err
-	}
-
 	wg := sync.WaitGroup{}
 	for p := range peers {
 		wg.Add(1)
 		go func(p peer.ID) {
 			defer wg.Done()
-			logger.Debugf("putProvider(%s, %s)", loggableProviderRecordBytes(keyMH), p)
-			err := dht.sendMessage(ctx, p, mes)
+			logger.Debugf("putProvider(%s, %s)", internal.LoggableProviderRecordBytes(keyMH), p)
+			err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.host)
 			if err != nil {
 				logger.Debug(err)
 			}
@@ -467,22 +462,6 @@
 	}
 	return ctx.Err()
 }
-func (dht *IpfsDHT) makeProvRecord(key []byte) (*pb.Message, error) {
-	pi := peer.AddrInfo{
-		ID:    dht.self,
-		Addrs: dht.host.Addrs(),
-	}
-
-	// // only share WAN-friendly addresses ??
-	// pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs)
-	if len(pi.Addrs) < 1 {
-		return nil, fmt.Errorf("no known addresses for self, cannot put provider")
-	}
-
-	pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, key, 0)
-	pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi})
-	return pmes, nil
-}

 // FindProviders searches until the context expires.
 func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
@@ -519,7 +498,7 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count i
 	keyMH := key.Hash()

-	logger.Debugw("finding providers", "cid", key, "mh", loggableProviderRecordBytes(keyMH))
+	logger.Debugw("finding providers", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH))
 	go dht.findProvidersAsyncRoutine(ctx, keyMH, count, peerOut)
 	return peerOut
 }
@@ -562,14 +541,12 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
 			ID:   p,
 		})

-		pmes, err := dht.findProvidersSingle(ctx, p, key)
+		provs, closest, err := dht.protoMessenger.GetProviders(ctx, p, key)
 		if err != nil {
 			return nil, err
 		}

-		logger.Debugf("%d provider entries", len(pmes.GetProviderPeers()))
-		provs := pb.PBPeersToPeerInfos(pmes.GetProviderPeers())
-		logger.Debugf("%d provider entries decoded", len(provs))
+		logger.Debugf("%d provider entries", len(provs))

 		// Add unique providers from request, up to 'count'
 		for _, prov := range provs {
@@ -591,17 +568,15 @@
 		}

 		// Give closer peers back to the query to be queried
-		closer := pmes.GetCloserPeers()
-		peers := pb.PBPeersToPeerInfos(closer)
-		logger.Debugf("got closer peers: %d %s", len(peers), peers)
+		logger.Debugf("got closer peers: %d %s", len(closest), closest)

 		routing.PublishQueryEvent(ctx, &routing.QueryEvent{
 			Type:      routing.PeerResponse,
 			ID:        p,
-			Responses: peers,
+			Responses: closest,
 		})

-		return peers, nil
+		return closest, nil
 	},
 	func() bool {
 		return !findAll && ps.Size() >= count
@@ -634,12 +609,11 @@ func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo,
 			ID:   p,
 		})

-		pmes, err := dht.findPeerSingle(ctx, p, id)
+		peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, id)
 		if err != nil {
 			logger.Debugf("error getting closer peers: %s", err)
 			return nil, err
 		}
-		peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

 		// For DHT query command
 		routing.PublishQueryEvent(ctx, &routing.QueryEvent{
diff --git a/subscriber_notifee.go b/subscriber_notifee.go
index 8211d25de..7cc9018f7 100644
--- a/subscriber_notifee.go
+++ b/subscriber_notifee.go
@@ -173,22 +173,7 @@ func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) {
 		return
 	}

-	dht.smlk.Lock()
-	defer dht.smlk.Unlock()
-	ms, ok := dht.strmap[p]
-	if !ok {
-		return
-	}
-	delete(dht.strmap, p)
-
-	// Do this asynchronously as ms.lk can block for a while.
-	go func() {
-		if err := ms.lk.Lock(dht.Context()); err != nil {
-			return
-		}
-		defer ms.lk.Unlock()
-		ms.invalidate()
-	}()
+	dht.msgSender.streamDisconnect(dht.Context(), p)
 }

 func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {}
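The seam this change introduces is the pb.MessageSender interface: ProtocolMessenger never touches streams itself, so any transport can be substituted. A minimal sketch of what that buys, assuming only the exported pb API added in pb/protocol_messenger.go; the echoSender type and ExamplePing function are hypothetical illustrations, not part of this change:

```go
package dht_test

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p-core/peer"
	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
)

// echoSender is a hypothetical MessageSender that answers requests itself
// instead of opening network streams. Because ProtocolMessenger depends only
// on the MessageSender interface, a stub like this is enough to exercise the
// protocol logic in isolation.
type echoSender struct{}

func (e *echoSender) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
	// Reply to pings with a ping, mirroring what a live peer would do.
	if pmes.Type == pb.Message_PING {
		return pb.NewMessage(pb.Message_PING, nil, 0), nil
	}
	return nil, fmt.Errorf("unhandled message type: %v", pmes.Type)
}

func (e *echoSender) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
	return nil // fire-and-forget messages are simply dropped
}

func ExamplePing() {
	// No validator option is needed for Ping; the real DHT wiring in New()
	// passes pb.WithValidator(dht.Validator) and uses messageSenderImpl here.
	pm, err := pb.NewProtocolMessenger(&echoSender{})
	if err != nil {
		panic(err)
	}
	if err := pm.Ping(context.Background(), peer.ID("remote")); err != nil {
		panic(err)
	}
	fmt.Println("ping ok")
	// Output: ping ok
}
```

Because Ping only needs SendRequest, no libp2p host or network stack is required to drive it; New() simply swaps messageSenderImpl in where the stub sits above.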
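GetValue's error contract also deserves a note: internal.ErrInvalidRecord is a sentinel rather than a hard failure, and the closer peers returned alongside it remain usable. A sketch of the intended handling, mirroring the switch in routing.go's getValues; the getValueFrom helper is hypothetical, and since internal is a Go internal package this only compiles inside the module:

```go
package dht

import (
	"context"

	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-core/routing"

	"github.com/libp2p/go-libp2p-kad-dht/internal"
	recpb "github.com/libp2p/go-libp2p-record/pb"
)

// getValueFrom illustrates how a ProtocolMessenger.GetValue caller is meant
// to branch on the three outcomes.
func (dht *IpfsDHT) getValueFrom(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
	rec, closerPeers, err := dht.protoMessenger.GetValue(ctx, p, key)
	switch err {
	case nil, internal.ErrInvalidRecord:
		// An invalid record is discarded (rec is nil), but the closer peers
		// from the same response are still good leads, so the lookup
		// continues with them rather than aborting.
		return rec, closerPeers, nil
	case routing.ErrNotFound:
		// The peer had neither the value nor any closer peers.
		return nil, nil, err
	default:
		// Transport-level failure (stream error, timeout, and so on).
		return nil, nil, err
	}
}
```

The design choice worth noting is that validation moved behind the messenger: pre-change, getValueOrPeers returned an emptied record plus the sentinel, whereas ProtocolMessenger.GetValue now returns nil for the record, which makes the "bad record but usable peers" case harder to misuse.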