diff --git a/core/peerstore/peerstore.go b/core/peerstore/peerstore.go index 10469e72cb..6366026c9d 100644 --- a/core/peerstore/peerstore.go +++ b/core/peerstore/peerstore.go @@ -31,6 +31,7 @@ var ( RecentlyConnectedAddrTTL = time.Minute * 15 // OwnObservedAddrTTL is used for our own external addresses observed by peers. + // // Deprecated: observed addresses are maintained till we disconnect from the peer which provided it OwnObservedAddrTTL = time.Minute * 30 ) @@ -65,6 +66,10 @@ type Peerstore interface { // Peers returns all the peer IDs stored across all inner stores. Peers() peer.IDSlice + + // RemovePeer removes all the peer related information except its addresses. To remove the + // addresses use `AddrBook.ClearAddrs` or set the address ttls to 0. + RemovePeer(peer.ID) } // PeerMetadata can handle values of any type. Serializing values is @@ -134,12 +139,13 @@ type AddrBook interface { // } type CertifiedAddrBook interface { // ConsumePeerRecord stores a signed peer record and the contained addresses for - // for ttl duration. + // ttl duration. // The addresses contained in the signed peer record will expire after ttl. If any // address is already present in the peer store, it'll expire at max of existing ttl and // provided ttl. // The signed peer record itself will be expired when all the addresses associated with the peer, // self-certified or not, are removed from the AddrBook. + // // To delete the signed peer record, use `AddrBook.UpdateAddrs`,`AddrBook.SetAddrs`, or // `AddrBook.ClearAddrs` with ttl 0. 
// Note: Future calls to ConsumePeerRecord will not expire self-certified addresses from the diff --git a/p2p/host/peerstore/pstoremem/addr_book.go b/p2p/host/peerstore/pstoremem/addr_book.go index 89b87bdb47..b528eb6989 100644 --- a/p2p/host/peerstore/pstoremem/addr_book.go +++ b/p2p/host/peerstore/pstoremem/addr_book.go @@ -3,13 +3,14 @@ package pstoremem import ( "container/heap" "context" + "errors" "fmt" "sort" "sync" "time" "github.com/libp2p/go-libp2p/core/peer" - pstore "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/record" logging "github.com/ipfs/go-log/v2" @@ -23,7 +24,7 @@ type expiringAddr struct { TTL time.Duration Expiry time.Time Peer peer.ID - // to sort by expiry time + // to sort by expiry time, -1 means it's not in the heap heapIndex int } @@ -31,6 +32,32 @@ func (e *expiringAddr) ExpiredBy(t time.Time) bool { return !t.Before(e.Expiry) } +func (e *expiringAddr) IsConnected() bool { + return ttlIsConnected(e.TTL) +} + +// ttlIsConnected returns true if the TTL is at least as long as the connected +// TTL. +func ttlIsConnected(ttl time.Duration) bool { + return ttl >= peerstore.ConnectedAddrTTL +} + +var expiringAddrPool = sync.Pool{New: func() any { return &expiringAddr{} }} + +func getExpiringAddrs() *expiringAddr { + a := expiringAddrPool.Get().(*expiringAddr) + a.heapIndex = -1 + return a +} + +func putExpiringAddrs(ea *expiringAddr) { + if ea == nil { + return + } + *ea = expiringAddr{} + expiringAddrPool.Put(ea) +} + type peerRecordState struct { Envelope *record.Envelope Seq uint64 @@ -40,7 +67,9 @@ type peerRecordState struct { var _ heap.Interface = &peerAddrs{} type peerAddrs struct { - Addrs map[peer.ID]map[string]*expiringAddr // peer.ID -> addr.Bytes() -> *expiringAddr + Addrs map[peer.ID]map[string]*expiringAddr // peer.ID -> addr.Bytes() -> *expiringAddr + // expiringHeap only stores non-connected addresses. 
Since connected address + // basically have an infinite TTL expiringHeap []*expiringAddr } @@ -61,10 +90,6 @@ func (pa *peerAddrs) Swap(i, j int) { } func (pa *peerAddrs) Push(x any) { a := x.(*expiringAddr) - if _, ok := pa.Addrs[a.Peer]; !ok { - pa.Addrs[a.Peer] = make(map[string]*expiringAddr) - } - pa.Addrs[a.Peer][string(a.Addr.Bytes())] = a a.heapIndex = len(pa.expiringHeap) pa.expiringHeap = append(pa.expiringHeap, a) } @@ -72,34 +97,24 @@ func (pa *peerAddrs) Pop() any { a := pa.expiringHeap[len(pa.expiringHeap)-1] a.heapIndex = -1 pa.expiringHeap = pa.expiringHeap[0 : len(pa.expiringHeap)-1] - - if m, ok := pa.Addrs[a.Peer]; ok { - delete(m, string(a.Addr.Bytes())) - if len(m) == 0 { - delete(pa.Addrs, a.Peer) - } - } return a } -func (pa *peerAddrs) Fix(a *expiringAddr) { - heap.Fix(pa, a.heapIndex) -} - func (pa *peerAddrs) Delete(a *expiringAddr) { - heap.Remove(pa, a.heapIndex) - a.heapIndex = -1 - if m, ok := pa.Addrs[a.Peer]; ok { - delete(m, string(a.Addr.Bytes())) - if len(m) == 0 { + if ea, ok := pa.Addrs[a.Peer][string(a.Addr.Bytes())]; ok { + if ea.heapIndex != -1 { // only unconnected addrs live in the heap + heap.Remove(pa, ea.heapIndex) // use the stored entry's index, the same one the guard checked + } + delete(pa.Addrs[a.Peer], string(a.Addr.Bytes())) + if len(pa.Addrs[a.Peer]) == 0 { delete(pa.Addrs, a.Peer) } } } -func (pa *peerAddrs) FindAddr(p peer.ID, addrBytes ma.Multiaddr) (*expiringAddr, bool) { +func (pa *peerAddrs) FindAddr(p peer.ID, addr ma.Multiaddr) (*expiringAddr, bool) { if m, ok := pa.Addrs[p]; ok { - v, ok := m[string(addrBytes.Bytes())] + v, ok := m[string(addr.Bytes())] return v, ok } return nil, false @@ -115,12 +130,44 @@ func (pa *peerAddrs) NextExpiry() time.Time { func (pa *peerAddrs) PopIfExpired(now time.Time) (*expiringAddr, bool) { // Use `!Before` instead of `After` to ensure that we expire *at* now, and not *just after now*. 
if len(pa.expiringHeap) > 0 && !now.Before(pa.NextExpiry()) { - a := heap.Pop(pa) - return a.(*expiringAddr), true + ea := heap.Pop(pa).(*expiringAddr) + delete(pa.Addrs[ea.Peer], string(ea.Addr.Bytes())) + if len(pa.Addrs[ea.Peer]) == 0 { + delete(pa.Addrs, ea.Peer) + } + return ea, true } return nil, false } +func (pa *peerAddrs) Update(a *expiringAddr) { // Update re-syncs a's heap membership after its TTL/Expiry changed. + switch { + case a.heapIndex == -1 && !a.IsConnected(): // left the connected state: start tracking its expiry + heap.Push(pa, a) + case a.heapIndex != -1 && a.IsConnected(): // became connected: connected addrs are kept out of the heap + heap.Remove(pa, a.heapIndex) + case a.heapIndex != -1: // still unconnected: restore heap order for the new expiry + heap.Fix(pa, a.heapIndex) + } +} + +func (pa *peerAddrs) Insert(a *expiringAddr) { + a.heapIndex = -1 + if _, ok := pa.Addrs[a.Peer]; !ok { + pa.Addrs[a.Peer] = make(map[string]*expiringAddr) + } + pa.Addrs[a.Peer][string(a.Addr.Bytes())] = a + // don't add connected addr to heap. + if a.IsConnected() { + return + } + heap.Push(pa, a) +} + +func (pa *peerAddrs) NumUnconnectedAddrs() int { + return len(pa.expiringHeap) +} + type clock interface { Now() time.Time } @@ -131,12 +178,18 @@ func (rc realclock) Now() time.Time { return time.Now() } +const ( + defaultMaxSignedPeerRecords = 100_000 + defaultMaxUnconnectedAddrs = 1_000_000 +) + // memoryAddrBook manages addresses. type memoryAddrBook struct { - mu sync.RWMutex - // TODO bound the number of not connected addresses we store. 
- addrs peerAddrs - signedPeerRecords map[peer.ID]*peerRecordState + mu sync.RWMutex + addrs peerAddrs + signedPeerRecords map[peer.ID]*peerRecordState + maxUnconnectedAddrs int + maxSignedPeerRecords int refCount sync.WaitGroup cancel func() @@ -145,18 +198,20 @@ type memoryAddrBook struct { clock clock } -var _ pstore.AddrBook = (*memoryAddrBook)(nil) -var _ pstore.CertifiedAddrBook = (*memoryAddrBook)(nil) +var _ peerstore.AddrBook = (*memoryAddrBook)(nil) +var _ peerstore.CertifiedAddrBook = (*memoryAddrBook)(nil) func NewAddrBook() *memoryAddrBook { ctx, cancel := context.WithCancel(context.Background()) ab := &memoryAddrBook{ - addrs: newPeerAddrs(), - signedPeerRecords: make(map[peer.ID]*peerRecordState), - subManager: NewAddrSubManager(), - cancel: cancel, - clock: realclock{}, + addrs: newPeerAddrs(), + signedPeerRecords: make(map[peer.ID]*peerRecordState), + subManager: NewAddrSubManager(), + cancel: cancel, + clock: realclock{}, + maxUnconnectedAddrs: defaultMaxUnconnectedAddrs, + maxSignedPeerRecords: defaultMaxSignedPeerRecords, } ab.refCount.Add(1) go ab.background(ctx) @@ -172,6 +227,23 @@ func WithClock(clock clock) AddrBookOption { } } +// WithMaxAddresses sets the maximum number of unconnected addresses to store. +// The maximum number of connected addresses is bounded by the connection +// limits in the Connection Manager and Resource Manager. 
+func WithMaxAddresses(n int) AddrBookOption { + return func(b *memoryAddrBook) error { + b.maxUnconnectedAddrs = n + return nil + } +} + +func WithMaxSignedPeerRecords(n int) AddrBookOption { + return func(b *memoryAddrBook) error { + b.maxSignedPeerRecords = n + return nil + } +} + // background periodically schedules a gc func (mab *memoryAddrBook) background(ctx context.Context) { defer mab.refCount.Done() @@ -204,6 +276,7 @@ func (mab *memoryAddrBook) gc() { if !ok { return } mab.maybeDeleteSignedPeerRecordUnlocked(ea.Peer) + putExpiringAddrs(ea) // recycle only after the last read of ea: putExpiringAddrs zeroes the struct, including ea.Peer } } @@ -252,6 +325,10 @@ func (mab *memoryAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, tt if found && lastState.Seq > rec.Seq { return false, nil } + // check if we are over the max signed peer record limit + if !found && len(mab.signedPeerRecords) >= mab.maxSignedPeerRecords { + return false, errors.New("too many signed peer records") + } mab.signedPeerRecords[rec.PeerID] = &peerRecordState{ Envelope: recordEnvelope, Seq: rec.Seq, @@ -281,6 +358,11 @@ func (mab *memoryAddrBook) addAddrsUnlocked(p peer.ID, addrs []ma.Multiaddr, ttl return } + // we are over limit, drop these addrs. + if !ttlIsConnected(ttl) && mab.addrs.NumUnconnectedAddrs() >= mab.maxUnconnectedAddrs { + return + } + exp := mab.clock.Now().Add(ttl) for _, addr := range addrs { // Remove suffix of /p2p/peer-id from address @@ -296,8 +378,9 @@ func (mab *memoryAddrBook) addAddrsUnlocked(p peer.ID, addrs []ma.Multiaddr, ttl a, found := mab.addrs.FindAddr(p, addr) if !found { // not found, announce it. 
- entry := &expiringAddr{Addr: addr, Expiry: exp, TTL: ttl, Peer: p} - heap.Push(&mab.addrs, entry) + entry := getExpiringAddrs() + *entry = expiringAddr{Addr: addr, Expiry: exp, TTL: ttl, Peer: p} + mab.addrs.Insert(entry) mab.subManager.BroadcastAddr(p, addr) } else { // update ttl & exp to whichever is greater between new and existing entry @@ -311,7 +394,7 @@ func (mab *memoryAddrBook) addAddrsUnlocked(p peer.ID, addrs []ma.Multiaddr, ttl a.Expiry = exp } if changed { - mab.addrs.Fix(a) + mab.addrs.Update(a) } } } @@ -344,17 +427,28 @@ func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du if a, found := mab.addrs.FindAddr(p, addr); found { if ttl > 0 { - a.Addr = addr - a.Expiry = exp - a.TTL = ttl - mab.addrs.Fix(a) - mab.subManager.BroadcastAddr(p, addr) + if a.IsConnected() && !ttlIsConnected(ttl) && mab.addrs.NumUnconnectedAddrs() >= mab.maxUnconnectedAddrs { + mab.addrs.Delete(a) + putExpiringAddrs(a) + } else { + a.Addr = addr + a.Expiry = exp + a.TTL = ttl + mab.addrs.Update(a) + mab.subManager.BroadcastAddr(p, addr) + } } else { mab.addrs.Delete(a) + putExpiringAddrs(a) } } else { if ttl > 0 { - heap.Push(&mab.addrs, &expiringAddr{Addr: addr, Expiry: exp, TTL: ttl, Peer: p}) + if !ttlIsConnected(ttl) && mab.addrs.NumUnconnectedAddrs() >= mab.maxUnconnectedAddrs { + continue + } + entry := getExpiringAddrs() + *entry = expiringAddr{Addr: addr, Expiry: exp, TTL: ttl, Peer: p} + mab.addrs.Insert(entry) mab.subManager.BroadcastAddr(p, addr) } } @@ -374,10 +468,17 @@ func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL t if oldTTL == a.TTL { if newTTL == 0 { mab.addrs.Delete(a) + putExpiringAddrs(a) } else { - a.TTL = newTTL - a.Expiry = exp - mab.addrs.Fix(a) + // We are over limit, drop these addresses. 
+ if ttlIsConnected(oldTTL) && !ttlIsConnected(newTTL) && mab.addrs.NumUnconnectedAddrs() >= mab.maxUnconnectedAddrs { + mab.addrs.Delete(a) + putExpiringAddrs(a) + } else { + a.TTL = newTTL + a.Expiry = exp + mab.addrs.Update(a) + } } } } @@ -436,6 +537,7 @@ func (mab *memoryAddrBook) ClearAddrs(p peer.ID) { delete(mab.signedPeerRecords, p) for _, a := range mab.addrs.Addrs[p] { mab.addrs.Delete(a) + putExpiringAddrs(a) } } diff --git a/p2p/host/peerstore/pstoremem/addr_book_test.go b/p2p/host/peerstore/pstoremem/addr_book_test.go index 963c4552cf..e8ba89ff93 100644 --- a/p2p/host/peerstore/pstoremem/addr_book_test.go +++ b/p2p/host/peerstore/pstoremem/addr_book_test.go @@ -22,8 +22,8 @@ func TestPeerAddrsNextExpiry(t *testing.T) { // t1 is before t2 t1 := time.Time{}.Add(1 * time.Second) t2 := time.Time{}.Add(2 * time.Second) - heap.Push(pa, &expiringAddr{Addr: a1, Expiry: t1, TTL: 10 * time.Second, Peer: "p1"}) - heap.Push(pa, &expiringAddr{Addr: a2, Expiry: t2, TTL: 10 * time.Second, Peer: "p2"}) + paa.Insert(&expiringAddr{Addr: a1, Expiry: t1, TTL: 10 * time.Second, Peer: "p1"}) + paa.Insert(&expiringAddr{Addr: a2, Expiry: t2, TTL: 10 * time.Second, Peer: "p2"}) if pa.NextExpiry() != t1 { t.Fatal("expiry should be set to t1, got", pa.NextExpiry()) @@ -49,7 +49,7 @@ func TestPeerAddrsHeapProperty(t *testing.T) { const N = 10000 expiringAddrs := peerAddrsInput(N) for i := 0; i < N; i++ { - heap.Push(pa, expiringAddrs[i]) + paa.Insert(expiringAddrs[i]) } for i := 0; i < N; i++ { @@ -70,7 +70,7 @@ func TestPeerAddrsHeapPropertyDeletions(t *testing.T) { const N = 10000 expiringAddrs := peerAddrsInput(N) for i := 0; i < N; i++ { - heap.Push(pa, expiringAddrs[i]) + paa.Insert(expiringAddrs[i]) } // delete every 3rd element @@ -108,7 +108,7 @@ func TestPeerAddrsHeapPropertyUpdates(t *testing.T) { var endElements []ma.Multiaddr for i := 0; i < N; i += 3 { expiringAddrs[i].Expiry = time.Time{}.Add(1000_000 * time.Second) - pa.Fix(expiringAddrs[i]) + 
pa.Update(expiringAddrs[i]) endElements = append(endElements, expiringAddrs[i].Addr) } @@ -148,7 +148,7 @@ func TestPeerAddrsExpiry(t *testing.T) { expiringAddrs[i].Expiry = time.Time{}.Add(time.Duration(1+rand.Intn(N)) * time.Second) } for i := 0; i < N; i++ { - heap.Push(pa, expiringAddrs[i]) + pa.Insert(expiringAddrs[i]) } expiry := time.Time{}.Add(time.Duration(1+rand.Intn(N)) * time.Second) @@ -174,6 +174,18 @@ func TestPeerAddrsExpiry(t *testing.T) { } } +func TestPeerLimits(t *testing.T) { + ab := NewAddrBook() + defer ab.Close() + ab.maxUnconnectedAddrs = 1024 + + peers := peerAddrsInput(2048) + for _, p := range peers { + ab.AddAddr(p.Peer, p.Addr, p.TTL) + } + require.Equal(t, 1024, ab.addrs.NumUnconnectedAddrs()) +} + func BenchmarkPeerAddrs(b *testing.B) { sizes := [...]int{1, 10, 100, 1000, 10_000, 100_000, 1000_000} for _, sz := range sizes { @@ -184,7 +196,7 @@ func BenchmarkPeerAddrs(b *testing.B) { pa := &paa expiringAddrs := peerAddrsInput(sz) for i := 0; i < sz; i++ { - heap.Push(pa, expiringAddrs[i]) + pa.Insert(expiringAddrs[i]) } b.StartTimer() for {