diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 2e56b7f2bb..5f8c63c73c 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -434,3 +434,7 @@ func (pn *peernet) notifyAll(notification func(f network.Notifiee)) { func (pn *peernet) ResourceManager() network.ResourceManager { return &network.NullResourceManager{} } + +func (pn *peernet) CanDial(addr ma.Multiaddr) bool { + return true +} diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 889e2b191a..4c93073ed3 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "golang.org/x/exp/rand" @@ -29,19 +30,27 @@ type AutoNAT struct { allowAllAddrs bool // for testing } -func New(h host.Host, dialer host.Host) (*AutoNAT, error) { +func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error) { + s := defaultSettings() + for _, o := range opts { + if err := o(s); err != nil { + return nil, err + } + } sub, err := h.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged)) if err != nil { return nil, fmt.Errorf("failed to subscribe to event.EvtLocalReachabilityChanged: %w", err) } ctx, cancel := context.WithCancel(context.Background()) + an := &AutoNAT{ - host: h, - ctx: ctx, - cancel: cancel, - sub: sub, - srv: &Server{dialer: dialer, host: h}, - cli: NewClient(h), + host: h, + ctx: ctx, + cancel: cancel, + sub: sub, + srv: NewServer(h, dialer, s), + cli: NewClient(h), + allowAllAddrs: s.allowAllAddrs, } an.cli.Register() @@ -54,7 +63,7 @@ func (an *AutoNAT) background() { for { select { case <-an.ctx.Done(): - an.srv.Stop() + an.srv.Disable() an.wg.Done() return case evt := <-an.sub.Out(): @@ -64,9 +73,9 @@ func (an *AutoNAT) background() { log.Errorf("Unexpected event %s of type %T", evt, evt) } if revt.Reachability == network.ReachabilityPrivate { - an.srv.Stop() + an.srv.Disable() } else { - an.srv.Start() + an.srv.Enable() } } } @@ -119,9 +128,9 @@ func (an *AutoNAT) validPeer() peer.ID { } type Result struct { - Addr ma.Multiaddr - Rch network.Reachability - Err error + Addr ma.Multiaddr + Reachability network.Reachability + Status pbv2.DialStatus } var ( diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 087d85cb7c..c50c055faa 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" @@ -19,11 +20,13 @@ import ( "github.com/stretchr/testify/require" ) -func newAutoNAT(t *testing.T) *AutoNAT { +func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT { t.Helper() h := bhost.NewBlankHost(swarmt.GenSwarm(t)) - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t)) - an, err := New(h, dialer) + if dialer == nil { + dialer = bhost.NewBlankHost(swarmt.GenSwarm(t)) + } + an, err := New(h, dialer, opts...) if err != nil { t.Error(err) } @@ -45,7 +48,7 @@ func parseAddrs(t *testing.T, msg *pbv2.Message) []ma.Multiaddr { } func TestValidPeer(t *testing.T) { - an := newAutoNAT(t) + an := newAutoNAT(t, nil) require.Equal(t, an.validPeer(), peer.ID("")) an.host.Peerstore().AddAddr("peer1", ma.StringCast("/ip4/127.0.0.1/tcp/1"), peerstore.PermanentAddrTTL) an.host.Peerstore().AddAddr("peer2", ma.StringCast("/ip4/127.0.0.1/tcp/2"), peerstore.PermanentAddrTTL) @@ -72,15 +75,14 @@ func TestValidPeer(t *testing.T) { } func TestAutoNATPrivateAddr(t *testing.T) { - an := newAutoNAT(t) + an := newAutoNAT(t, nil) res, err := an.CheckReachability(context.Background(), []ma.Multiaddr{ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}, nil) require.Nil(t, res) require.NotNil(t, err) } func TestClientRequest(t *testing.T) { - an := newAutoNAT(t) - an.allowAllAddrs = true + an := newAutoNAT(t, nil) addrs := an.host.Addrs() @@ -111,8 +113,7 @@ func TestClientRequest(t *testing.T) { } func TestClientServerError(t *testing.T) { - an := newAutoNAT(t) - an.allowAllAddrs = true + an := newAutoNAT(t, nil, allowAll) addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) @@ -122,7 +123,10 @@ func TestClientServerError(t *testing.T) { tests := []struct { handler func(network.Stream) }{ - {handler: func(s network.Stream) { s.Reset(); done <- true }}, + {handler: func(s network.Stream) { + s.Reset() + done <- true + }}, {handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) var msg pbv2.Message @@ -156,8 +160,7 @@ func TestClientServerError(t *testing.T) { } func TestClientDataRequest(t *testing.T) { - an := newAutoNAT(t) - an.allowAllAddrs = true + an := newAutoNAT(t, nil, allowAll) addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) @@ -230,8 +233,7 @@ func TestClientDataRequest(t *testing.T) { } func TestClientDialAttempts(t *testing.T) { - an := newAutoNAT(t) - an.allowAllAddrs = true + an := newAutoNAT(t, nil, allowAll) addrs := an.host.Addrs() p := bhost.NewBlankHost(swarmt.GenSwarm(t)) @@ -245,6 +247,9 @@ func TestClientDialAttempts(t *testing.T) { }{ { handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pbv2.Message + r.ReadMsg(&msg) resp := &pbv2.DialResponse{ Status: pbv2.DialResponse_ResponseStatus_OK, DialStatuses: []pbv2.DialStatus{pbv2.DialStatus_OK}, @@ -399,13 +404,13 @@ func TestClientDialAttempts(t *testing.T) { require.NoError(t, err) if !tc.success { for i := 0; i < len(res); i++ { - require.Error(t, res[i].Err) - require.Equal(t, res[i].Rch, network.ReachabilityUnknown) + require.NotEqual(t, res[i].Status, pbv2.DialStatus_OK) + require.Equal(t, res[i].Reachability, network.ReachabilityUnknown) } } else { success := false for i := 0; i < len(res); i++ { - if res[i].Rch == network.ReachabilityPublic { + if res[i].Reachability == network.ReachabilityPublic { success = true break } diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index 8b84a21649..9361b3808a 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -144,27 +144,26 @@ func (ac *Client) newResults(ds []pbv2.DialStatus, highPriorityAddrs []ma.Multia } else { addr = lowPriorityAddrs[i-len(highPriorityAddrs)] } - err := ErrDialNotAttempted rch := network.ReachabilityUnknown + status := pbv2.DialStatus_SKIPPED if i < len(ds) { switch ds[i] { case pbv2.DialStatus_OK: if areAddrsConsistent(attempt, addr) { - err = nil + status = pbv2.DialStatus_OK rch = network.ReachabilityPublic } else { - err = errors.New("attempt error") + status = pbv2.DialStatus_E_ATTEMPT_ERROR rch = network.ReachabilityUnknown } case pbv2.DialStatus_E_DIAL_ERROR: - err = errors.New("dial failed") rch = network.ReachabilityPrivate default: - err = errors.New("other") + status = ds[i] rch = network.ReachabilityUnknown } } - res[i] = Result{Addr: addr, Rch: rch, Err: err} + res[i] = Result{Addr: addr, Reachability: rch, Status: status} } return res } diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go new file mode 100644 index 0000000000..c08919e7b2 --- /dev/null +++ b/p2p/protocol/autonatv2/options.go @@ -0,0 +1,50 @@ +package autonatv2 + +import "time" + +type autoNATSettings struct { + allowAllAddrs bool + serverRPM int + serverRPMPerPeer int + dataRequestPolicy dataRequestPolicyFunc + now func() time.Time +} + +func defaultSettings() *autoNATSettings { + return &autoNATSettings{ + allowAllAddrs: false, + serverRPM: 20, + serverRPMPerPeer: 2, + dataRequestPolicy: defaultDataRequestPolicy, + now: time.Now, + } +} + +type AutoNATOption func(s *autoNATSettings) error + +func allowAll(s *autoNATSettings) error { + s.allowAllAddrs = true + return nil +} + +func WithServerRateLimit(rpm, rpmPerPeer int) AutoNATOption { + return func(s *autoNATSettings) error { + s.serverRPM = rpm + s.serverRPMPerPeer = rpmPerPeer + return nil + } +} + +func WithDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { + return func(s *autoNATSettings) error { + s.dataRequestPolicy = drp + return nil + } +} + +func WithNow(now func() time.Time) AutoNATOption { + return func(s *autoNATSettings) error { + s.now = now + return nil + } +} diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index a78e429925..553ac341d0 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -3,6 +3,7 @@ package autonatv2 import ( "context" "fmt" + "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -11,7 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2" "github.com/libp2p/go-msgio/pbio" - "github.com/multiformats/go-multiaddr" + ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "golang.org/x/exp/rand" @@ -26,21 +27,30 @@ type Server struct { host host.Host dataRequestPolicy dataRequestPolicyFunc allowAllAddrs bool + limiter *rateLimiter + now func() time.Time // for tests } -func NewServer(host, dialer host.Host, dataRequestPolicy dataRequestPolicyFunc, allowAllAddrs bool) *Server { - drp := defaultDataRequestPolicy - if dataRequestPolicy != nil { - drp = dataRequestPolicy +func NewServer(host, dialer host.Host, s *autoNATSettings) *Server { + return &Server{ + dialer: dialer, + host: host, + dataRequestPolicy: s.dataRequestPolicy, + allowAllAddrs: s.allowAllAddrs, + limiter: &rateLimiter{ + RPM: s.serverRPM, + RPMPerPeer: s.serverRPMPerPeer, + now: s.now, + }, + now: s.now, } - return &Server{dialer: dialer, host: host, dataRequestPolicy: drp, allowAllAddrs: allowAllAddrs} } -func (as *Server) Start() { +func (as *Server) Enable() { as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) } -func (as *Server) Stop() { +func (as *Server) Disable() { as.host.RemoveStreamHandler(DialProtocol) } @@ -58,7 +68,7 @@ func (as *Server) handleDialRequest(s network.Stream) { } defer s.Scope().ReleaseMemory(maxMsgSize) - s.SetDeadline(time.Now().Add(time.Minute)) + s.SetDeadline(as.now().Add(time.Minute)) defer s.Close() r := pbio.NewDelimitedReader(s, maxMsgSize) @@ -73,11 +83,17 @@ func (as *Server) handleDialRequest(s network.Stream) { log.Debugf("invalid message type: %T", msg.Msg) return } + if !as.limiter.Accept(s.Conn().RemotePeer()) { + s.Reset() + log.Debugf("rate limited request from %s", s.Conn().RemotePeer()) + return + } + nonce := msg.GetDialRequest().Nonce statuses := make([]pbv2.DialStatus, 0, len(msg.GetDialRequest().GetAddrs())) var dialAddr ma.Multiaddr for _, ab := range msg.GetDialRequest().GetAddrs() { - a, err := multiaddr.NewMultiaddrBytes(ab) + a, err := ma.NewMultiaddrBytes(ab) if err != nil { statuses = append(statuses, pbv2.DialStatus_E_ADDRESS_UNKNOWN) continue @@ -101,7 +117,11 @@ func (as *Server) handleDialRequest(s network.Stream) { w := pbio.NewDelimitedWriter(s) if dialAddr == nil { msg := getResponseMsg(pbv2.DialResponse_ResponseStatus_OK, statuses) - w.WriteMsg(msg) + if err := w.WriteMsg(msg); err != nil { + log.Debugf("failed to write response: %s", err) + s.Reset() + return + } return } @@ -109,8 +129,8 @@ func (as *Server) handleDialRequest(s network.Stream) { msg.Reset() err := getDialData(w, r, len(statuses)) if err != nil { - s.Reset() log.Debugf("peer refused data request: %s", err) + s.Reset() return } } @@ -174,7 +194,7 @@ func (as *Server) attemptDial(p peer.ID, addr ma.Multiaddr, nonce uint64) pbv2.D return pbv2.DialStatus_E_DIAL_ERROR } defer s.Close() - s.SetDeadline(time.Now().Add(5 * time.Second)) + s.SetDeadline(as.now().Add(5 * time.Second)) w := pbio.NewDelimitedWriter(s) if err := w.WriteMsg(&pbv2.DialAttempt{Nonce: nonce}); err != nil { @@ -183,7 +203,7 @@ func (as *Server) attemptDial(p peer.ID, addr ma.Multiaddr, nonce uint64) pbv2.D } // s.Close() here might discard the message s.CloseWrite() - s.SetDeadline(time.Now().Add(1 * time.Second)) + s.SetDeadline(as.now().Add(1 * time.Second)) b := make([]byte, 1) s.Read(b) @@ -200,3 +220,56 @@ func getResponseMsg(respStatus pbv2.DialResponse_ResponseStatus, statuses []pbv2 }, } } + +// rateLimiter implements a sliding window rate limit of requests per minute. +type rateLimiter struct { + RPMPerPeer int + RPM int + + mu sync.Mutex + reqs []time.Time + peerReqs map[peer.ID][]time.Time + now func() time.Time // for tests +} + +func (r *rateLimiter) Accept(p peer.ID) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.peerReqs == nil { + r.peerReqs = make(map[peer.ID][]time.Time) + } + + nw := r.now() + r.cleanup(p, nw) + + if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.RPMPerPeer { + return false + } + r.reqs = append(r.reqs, nw) + r.peerReqs[p] = append(r.peerReqs[p], nw) + return true +} + +// cleanup removes stale requests. +// +// This is fast enough in rate limited cases and the state is small enough to +// clean up quickly when blocking requests. +func (r *rateLimiter) cleanup(p peer.ID, now time.Time) { + idx := len(r.reqs) + for i, t := range r.reqs { + if now.Sub(t).Minutes() <= 1 { + idx = i + break + } + } + r.reqs = r.reqs[idx:] + + idx = len(r.peerReqs[p]) + for i, t := range r.peerReqs[p] { + if now.Sub(t).Minutes() <= 1 { + idx = i + break + } + } + r.peerReqs[p] = r.peerReqs[p][idx:] +} diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 8853dc35df..b8055b723b 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -3,84 +3,84 @@ package autonatv2 import ( "context" "testing" + "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/test" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) +// identify provides server address and protocol to client +func identify(cli *AutoNAT, srv *AutoNAT) { + cli.host.Peerstore().AddAddrs(srv.host.ID(), srv.host.Addrs(), peerstore.PermanentAddrTTL) + cli.host.Peerstore().AddProtocols(srv.host.ID(), DialProtocol) +} + func TestServerAllAddrsInvalid(t *testing.T) { - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) - as, err := New(h, dialer) - require.NoError(t, err) - defer as.Close() - defer as.host.Close() + an := newAutoNAT(t, dialer, allowAll) + defer an.Close() + defer an.host.Close() + an.srv.Enable() - as.srv.Start() - - c := newAutoNAT(t) - c.allowAllAddrs = true + c := newAutoNAT(t, nil, allowAll) defer c.Close() defer c.host.Close() - c.host.Peerstore().AddAddrs(as.host.ID(), as.host.Addrs(), peerstore.PermanentAddrTTL) - c.host.Peerstore().AddProtocols(as.host.ID(), DialProtocol) + identify(c, an) res, err := c.CheckReachability(context.Background(), c.host.Addrs(), nil) require.NoError(t, err) for _, r := range res { - require.Error(t, r.Err) + require.Equal(t, r.Status, pbv2.DialStatus_E_TRANSPORT_NOT_SUPPORTED) } } func TestServerPrivateRejected(t *testing.T) { - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t)) - as, err := New(h, dialer) - require.NoError(t, err) - defer as.Close() - defer as.host.Close() + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + an.srv.Enable() - as.srv.Start() - - c := newAutoNAT(t) - c.allowAllAddrs = true + c := newAutoNAT(t, nil, allowAll) defer c.Close() defer c.host.Close() - c.host.Peerstore().AddAddrs(as.host.ID(), as.host.Addrs(), peerstore.PermanentAddrTTL) - c.host.Peerstore().AddProtocols(as.host.ID(), DialProtocol) + identify(c, an) res, err := c.CheckReachability(context.Background(), c.host.Addrs(), nil) require.NoError(t, err) for _, r := range res { - require.Error(t, r.Err) + require.Equal(t, r.Status, pbv2.DialStatus_E_DIAL_REFUSED) } } func TestServerDataRequest(t *testing.T) { - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) - as := NewServer(h, dialer, func(s network.Stream, dialAddr ma.Multiaddr) bool { - if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { - return true - } - return false - }, true) - defer as.host.Close() - as.Start() - - c := newAutoNAT(t) + an := newAutoNAT(t, dialer, allowAll, WithDataRequestPolicy( + func(s network.Stream, dialAddr ma.Multiaddr) bool { + if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return true + } + return false + }), + WithServerRateLimit(10, 10), + ) + an.srv.Enable() + defer an.host.Close() + + c := newAutoNAT(t, nil) c.allowAllAddrs = true defer c.Close() defer c.host.Close() - c.host.Peerstore().AddAddrs(as.host.ID(), as.host.Addrs(), peerstore.PermanentAddrTTL) - c.host.Peerstore().AddProtocols(as.host.ID(), DialProtocol) + identify(c, an) + var quicAddr, tcpAddr ma.Multiaddr for _, a := range c.host.Addrs() { if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { @@ -96,29 +96,48 @@ func TestServerDataRequest(t *testing.T) { res, err := c.CheckReachability(context.Background(), []ma.Multiaddr{quicAddr}, []ma.Multiaddr{tcpAddr}) require.NoError(t, err) - require.Equal(t, res[0].Rch, network.ReachabilityPublic) + require.Equal(t, res[0].Reachability, network.ReachabilityPublic) } func TestServerDial(t *testing.T) { - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t)) - as := NewServer(h, dialer, nil, true) - defer as.host.Close() - as.Start() + an := newAutoNAT(t, nil, WithServerRateLimit(10, 10), allowAll) + defer an.host.Close() + an.srv.Enable() - c := newAutoNAT(t) - c.allowAllAddrs = true + c := newAutoNAT(t, nil, allowAll) defer c.Close() defer c.host.Close() - c.host.Peerstore().AddAddrs(as.host.ID(), as.host.Addrs(), peerstore.PermanentAddrTTL) - c.host.Peerstore().AddProtocols(as.host.ID(), DialProtocol) + identify(c, an) + randAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") res, err := c.CheckReachability(context.Background(), []ma.Multiaddr{randAddr}, c.host.Addrs()) require.NoError(t, err) - require.Equal(t, res[0].Rch, network.ReachabilityPrivate) + require.Equal(t, res[0].Reachability, network.ReachabilityPrivate) res, err = c.CheckReachability(context.Background(), nil, c.host.Addrs()) require.NoError(t, err) - require.Equal(t, res[0].Rch, network.ReachabilityPublic) + require.Equal(t, res[0].Reachability, network.ReachabilityPublic) +} + +func TestRateLimiter(t *testing.T) { + cl := test.NewMockClock() + r := rateLimiter{RPM: 3, RPMPerPeer: 2, now: cl.Now} + + require.True(t, r.Accept("peer1")) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer1")) + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1")) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer2")) + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer3")) + + cl.AdvanceBy(21 * time.Second) // first request expired + require.True(t, r.Accept("peer1")) }