diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 26811d0943..473183b308 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -82,11 +82,12 @@ type AutoNAT struct { mx sync.Mutex peers *peersMap - allowAllAddrs bool // for testing + // allowAllAddrs enables using private and localhost addresses for reachability checks. + // This is only useful for testing. + allowAllAddrs bool } -// New returns a new AutoNAT instance. The returned instance runs the server when the provided host -// is publicly reachable. +// New returns a new AutoNAT instance. // host and dialerHost should have the same dialing capabilities. In case the host doesn't support // a transport, dial back requests for address for that transport will be ignored. func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { @@ -99,19 +100,12 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, // We are listening on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. // - // We listen on event.EvtLocalReachabilityChanged to Disable the server if we are not - // publicly reachable. Currently this event is sent by the AutoNAT v1 module. During the - // transition period from AutoNAT v1 to v2, there won't be enough v2 servers on the network - // and most clients will be unable to discover a peer which supports AutoNAT v2. So, we use - // v1 to determine reachability for the transition period. - // // Once there are enough v2 servers on the network for nodes to determine their reachability // using AutoNAT v2, we'll use Address Pipeline // (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a future release) // to determine reachability using v2 client and send this event from Address Pipeline, if // we are publicly reachable. sub, err := host.EventBus().Subscribe([]interface{}{ - new(event.EvtLocalReachabilityChanged), new(event.EvtPeerProtocolsUpdated), new(event.EvtPeerConnectednessChanged), new(event.EvtPeerIdentificationCompleted), @@ -132,6 +126,7 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, peers: newPeersMap(), } an.cli.RegisterDialBack() + an.srv.Enable() an.wg.Add(1) go an.background() @@ -149,12 +144,6 @@ func (an *AutoNAT) background() { return case e := <-an.sub.Out(): switch evt := e.(type) { - case event.EvtLocalReachabilityChanged: - if evt.Reachability == network.ReachabilityPrivate { - an.srv.Disable() - } else { - an.srv.Enable() - } case event.EvtPeerProtocolsUpdated: an.updatePeer(evt.Peer) case event.EvtPeerConnectednessChanged: @@ -171,8 +160,8 @@ func (an *AutoNAT) Close() { an.wg.Wait() } -// CheckReachability makes a single dial request for checking reachability for requested addresses -func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Result, error) { +// GetReachability makes a single dial request for checking reachability for requested addresses +func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) { if !an.allowAllAddrs { for _, r := range reqs { if !manet.IsPublicAddr(r.Addr) { @@ -185,7 +174,7 @@ func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Resul return Result{}, ErrNoValidPeers } - res, err := an.cli.CheckReachability(ctx, p, reqs) + res, err := an.cli.GetReachability(ctx, p, reqs) if err != nil { log.Debugf("reachability check with %s failed, err: %s", p, err) return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index f1e8299cb2..e28e803469 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -80,7 +80,7 @@ func idAndWait(t *testing.T, cli *AutoNAT, srv *AutoNAT) { func TestAutoNATPrivateAddr(t *testing.T) { an := newAutoNAT(t, nil) - res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) + res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) require.Equal(t, res, Result{}) require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") } @@ -112,7 +112,7 @@ func TestClientRequest(t *testing.T) { s.Reset() }) - res, err := an.CheckReachability(context.Background(), []Request{ + res, err := an.GetReachability(context.Background(), []Request{ {Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]}, }) require.Equal(t, res, Result{}) @@ -167,7 +167,7 @@ func TestClientServerError(t *testing.T) { t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { b.SetStreamHandler(DialProtocol, tc.handler) addrs := an.host.Addrs() - res, err := an.CheckReachability( + res, err := an.GetReachability( context.Background(), newTestRequests(addrs, false)) require.Equal(t, res, Result{}) @@ -280,7 +280,7 @@ func TestClientDataRequest(t *testing.T) { b.SetStreamHandler(DialProtocol, tc.handler) addrs := an.host.Addrs() - res, err := an.CheckReachability( + res, err := an.GetReachability( context.Background(), []Request{ {Addr: addrs[0], SendDialData: true}, @@ -489,7 +489,7 @@ func TestClientDialBacks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { addrs := an.host.Addrs() b.SetStreamHandler(DialProtocol, tc.handler) - res, err := an.CheckReachability( + res, err := an.GetReachability( context.Background(), []Request{ {Addr: addrs[0], SendDialData: true}, diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index d824db55b8..616931a667 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -38,8 +38,8 @@ func (ac *client) RegisterDialBack() { ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) } -// CheckReachability verifies address reachability with a AutoNAT v2 server p. -func (ac *client) CheckReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) { +// GetReachability verifies address reachability with a AutoNAT v2 server p. +func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) { ctx, cancel := context.WithTimeout(ctx, streamTimeout) defer cancel() diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 5e788173cb..64f66a7ee8 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -31,8 +31,8 @@ type server struct { dialerHost host.Host limiter *rateLimiter - // dialDataRequestPolicy is used to determine whether dialing the address requires receiving dial data. - // It is set to amplification attack prevention by default. + // dialDataRequestPolicy is used to determine whether dialing the address requires receiving + // dial data. It is set to amplification attack prevention by default. dialDataRequestPolicy dataRequestPolicyFunc // for tests @@ -98,7 +98,7 @@ func (as *server) handleDialRequest(s network.Stream) { } if msg.GetDialRequest() == nil { s.Reset() - log.Debugf("invalid message type from %s: %T", p, msg.Msg) + log.Debugf("invalid message type from %s: %T expected: DialRequest", p, msg.Msg) return } @@ -119,7 +119,7 @@ func (as *server) handleDialRequest(s network.Stream) { continue } // Check if the host can dial the address. This check ensures that we do not - // attempt dialing an IPv6 address if we have no IPv6 connectivity as the host dialer's + // attempt dialing an IPv6 address if we have no IPv6 connectivity as the host's // black hole detector is likely to be more accurate. if as.host.Network().CanDial(p, a) != network.DialabilityDialable { continue @@ -141,14 +141,13 @@ func (as *server) handleDialRequest(s network.Stream) { } if err := w.WriteMsg(&msg); err != nil { s.Reset() - log.Debugf("failed to write response to %s: %s", p, err) + log.Debugf("failed to write dial refused response to %s: %s", p, err) return } return } isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) - if !as.limiter.Accept(p, isDialDataRequired) { msg = pb.Message{ Msg: &pb.Message_DialResponse{ @@ -159,10 +158,10 @@ func (as *server) handleDialRequest(s network.Stream) { } if err := w.WriteMsg(&msg); err != nil { s.Reset() - log.Debugf("failed to write response to %s: %s", p, err) + log.Debugf("failed to write request rejected response to %s: %s", p, err) return } - log.Debugf("rejecting request from %s: rate limit exceeded", p) + log.Debugf("rejected request from %s: rate limit exceeded", p) return } defer as.limiter.CompleteRequest(p) @@ -248,12 +247,14 @@ func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt return pb.DialStatus_E_DIAL_BACK_ERROR } - // Since the underlying connection is on a separate dialer, it'll be closed after this function returns. - // Connection close will drop all the queued writes. To ensure message delivery, do a CloseWrite and - // wait a second for the peer to Close its end of the stream. + // Since the underlying connection is on a separate dialer, it'll be closed after this + // function returns. Connection close will drop all the queued writes. + // To ensure message delivery, do a CloseWrite and read a byte from the stream. The peer + // actually sends a DialDataResponse back but we only care about the fact that the DialBack + // message has reached the peer. So we ignore that message on the read side. s.CloseWrite() - s.SetDeadline(as.now().Add(1 * time.Second)) - b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately + s.SetDeadline(as.now().Add(5 * time.Second)) // 5 is a magic number + b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately s.Read(b) return pb.DialStatus_OK @@ -275,6 +276,7 @@ type rateLimiter struct { dialDataReqs []time.Time // ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the // same peer + // TODO: Should we allow a few concurrent requests per peer? ongoingReqs map[peer.ID]struct{} now func() time.Time // for tests diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 00409bac6d..46d8697eb5 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -35,7 +35,7 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) require.ErrorIs(t, err, ErrDialRefused) require.Equal(t, Result{}, res) }) @@ -47,7 +47,7 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) require.ErrorIs(t, err, ErrDialRefused) require.Equal(t, Result{}, res) }) @@ -84,10 +84,10 @@ func TestServerDataRequest(t *testing.T) { } } - _, err := c.CheckReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}}) + _, err := c.GetReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}}) require.Error(t, err) - res, err := c.CheckReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) require.NoError(t, err) require.Equal(t, Result{ @@ -113,7 +113,7 @@ func TestServerDial(t *testing.T) { hostAddrs := c.host.Addrs() t.Run("unreachable addr", func(t *testing.T) { - res, err := c.CheckReachability(context.Background(), + res, err := c.GetReachability(context.Background(), append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) require.NoError(t, err) require.Equal(t, Result{ @@ -125,7 +125,7 @@ func TestServerDial(t *testing.T) { }) t.Run("reachable addr", func(t *testing.T) { - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) require.NoError(t, err) require.Equal(t, Result{ Idx: 0, @@ -137,7 +137,7 @@ func TestServerDial(t *testing.T) { t.Run("dialback error", func(t *testing.T) { c.host.RemoveStreamHandler(DialBackProtocol) - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) require.NoError(t, err) require.Equal(t, Result{ Idx: 0,