Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swarm: return errors on filtered addresses when dialing #2461

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions p2p/net/swarm/black_hole_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ type blackHoleDetector struct {
}

// FilterAddrs filters the peer's addresses removing black holed addresses
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) {
hasUDP, hasIPv6 := false, false
for _, a := range addrs {
if !manet.IsPublicAddr(a) {
Expand All @@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
ipv6Res = d.ipv6.HandleRequest()
}

blackHoled = make([]ma.Multiaddr, 0, len(addrs))
return ma.FilterAddrs(
addrs,
func(a ma.Multiaddr) bool {
Expand All @@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
}

if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) {
blackHoled = append(blackHoled, a)
return false
}
if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) {
blackHoled = append(blackHoled, a)
return false
}
return true
},
)
), blackHoled
}

// RecordResult updates the state of the relevant `blackHoleFilter`s for addr
Expand Down
44 changes: 31 additions & 13 deletions p2p/net/swarm/black_hole_detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) {
ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"),
}
for i := 0; i < 1000; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
require.ElementsMatch(t, addrs, filteredAddrs)
for j := 0; j < len(addrs); j++ {
bhd.RecordResult(addrs[j], false)
Expand All @@ -101,20 +101,29 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) {
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
addrs := []ma.Multiaddr{publicAddr, privAddr}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))
wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorIPv6Disabled(t *testing.T) {
udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5}
bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil)
publicAddr := ma.StringCast("/ip6/1::1/tcp/1234")
privAddr := ma.StringCast("/ip6/::1/tcp/1234")
addrs := []ma.Multiaddr{publicAddr, privAddr}
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))

wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorProbes(t *testing.T) {
Expand All @@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
bhd.RecordResult(udp6Addr, false)
}
for i := 1; i < 100; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
if i%2 == 0 || i%3 == 0 {
if len(filteredAddrs) == 0 {
t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter")
Expand All @@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1")
udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1")
upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1")
tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1")
tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1")
Expand All @@ -158,26 +167,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"},
}
for i := 0; i < 100; i++ {
bhd.RecordResult(upd4Pub, !udpBlocked)
bhd.RecordResult(udp4Pub, !udpBlocked)
}
for i := 0; i < 100; i++ {
bhd.RecordResult(tcp6Pub, !ipv6Blocked)
}
return bhd
}

allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri,
allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri,
tcp4Pub, tcp4Pri}

udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri}
udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub}
bhd := makeBHD(true, false)
require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput)
require.ElementsMatch(t, udpBlockedOutput, gotAddrs)
require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs)

ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub}
bhd = makeBHD(false, true)
require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, ip6BlockedOutput, gotAddrs)
require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs)

bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub}
bhd = makeBHD(true, true)
require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, bothBlockedOutput, gotAddrs)
require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs)
}
13 changes: 10 additions & 3 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,14 @@ loop:
continue loop
}

addrs, err := w.s.addrsForDial(req.ctx, w.peer)
addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: err}
req.resch <- dialResponse{
err: &DialError{
Peer: w.peer,
DialErrors: addrErrs,
Cause: err,
}}
continue loop
}

Expand All @@ -179,8 +184,8 @@ loop:
// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer},
addrs: make(map[string]struct{}, len(addrRanking)),
err: &DialError{Peer: w.peer, DialErrors: addrErrs},
}
for _, adelay := range addrRanking {
pr.addrs[string(adelay.Addr.Bytes())] = struct{}{}
Expand Down Expand Up @@ -221,6 +226,7 @@ loop:

if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
pr.err.Cause = ErrAllDialsFailed
req.resch <- dialResponse{err: pr.err}
continue loop
}
Expand Down Expand Up @@ -371,6 +377,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
pr.err.Cause = ErrAllDialsFailed
pr.req.resch <- dialResponse{err: pr.err}
}
delete(w.pendingRequests, pr)
Expand Down
61 changes: 39 additions & 22 deletions p2p/net/swarm/swarm_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,10 @@ func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) {
w.loop()
}

func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Multiaddr, addrErrs []TransportError, err error) {
peerAddrs := s.peers.Addrs(p)
if len(peerAddrs) == 0 {
return nil, ErrNoAddresses
return nil, nil, ErrNoAddresses
}

peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs))
Expand All @@ -308,22 +308,22 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er
Addrs: peerAddrsAfterTransportResolved,
})
if err != nil {
return nil, err
return nil, nil, err
}

goodAddrs := s.filterKnownUndialables(p, resolved)
goodAddrs = ma.Unique(resolved)
goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs)
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
goodAddrs = ma.Unique(goodAddrs)

if len(goodAddrs) == 0 {
return nil, ErrNoGoodAddresses
return nil, addrErrs, ErrNoGoodAddresses
}

s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL)

return goodAddrs, nil
return goodAddrs, addrErrs, nil
}

func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) {
Expand Down Expand Up @@ -402,11 +402,6 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr,
return nil
}

func (s *Swarm) canDial(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return t != nil && t.CanDial(addr)
}

func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return !t.Proxy()
Expand All @@ -418,7 +413,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
// addresses that we know to be our own, and addresses with a better tranport
// available. This is an optimization to avoid wasting time on dials that we
// know are going to fail or for which we have a better alternative.
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr {
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) {
lisAddrs, _ := s.InterfaceListenAddresses()
var ourAddrs []ma.Multiaddr
for _, addr := range lisAddrs {
Expand All @@ -431,27 +426,49 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul
})
}

// The order of these two filters is important. If we can only dial /webtransport,
// we don't want to filter /webtransport addresses out because the peer had a /quic-v1
// address
addrErrs = make([]TransportError, 0, len(addrs))

// filter addresses we cannot dial
addrs = ma.FilterAddrs(addrs, s.canDial)
// The order of checking for transport and filtering low priority addrs is important. If we
// can only dial /webtransport, we don't want to filter /webtransport addresses out because
// the peer had a /quic-v1 address

// filter addresses with no transport
addrs = ma.FilterAddrs(addrs, func(a ma.Multiaddr) bool {
if s.TransportForDialing(a) == nil {
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrNoTransport})
return false
}
return true
})

// filter low priority addresses among the addresses we can dial
// We don't return an error for these addresses
addrs = filterLowPriorityAddresses(addrs)

// remove black holed addrs
addrs = s.bhd.FilterAddrs(addrs)
addrs, blackHoledAddrs := s.bhd.FilterAddrs(addrs)
for _, a := range blackHoledAddrs {
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrDialRefusedBlackHole})
}

return ma.FilterAddrs(addrs,
func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) },
func(addr ma.Multiaddr) bool {
if ma.Contains(ourAddrs, addr) {
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialToSelf})
return false
}
return true
},
// TODO: Consider allowing link-local addresses
func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) },
func(addr ma.Multiaddr) bool {
return s.gater == nil || s.gater.InterceptAddrDial(p, addr)
if s.gater != nil && !s.gater.InterceptAddrDial(p, addr) {
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrGaterDisallowedConnection})
return false
}
return true
},
)
), addrErrs
}

// limitedDial will start a dial to the given peer when
Expand Down
21 changes: 11 additions & 10 deletions p2p/net/swarm/swarm_dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/rand"
"errors"
"net"
"sort"
"testing"
Expand Down Expand Up @@ -65,7 +66,7 @@ func TestAddrsForDial(t *testing.T) {
ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour)

ctx := context.Background()
mas, err := s.addrsForDial(ctx, otherPeer)
mas, _, err := s.addrsForDial(ctx, otherPeer)
require.NoError(t, err)

require.NotZero(t, len(mas))
Expand Down Expand Up @@ -110,7 +111,7 @@ func TestDedupAddrsForDial(t *testing.T) {
ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour)

ctx := context.Background()
mas, err := s.addrsForDial(ctx, otherPeer)
mas, _, err := s.addrsForDial(ctx, otherPeer)
require.NoError(t, err)

require.Equal(t, 1, len(mas))
Expand Down Expand Up @@ -183,7 +184,7 @@ func TestAddrResolution(t *testing.T) {

tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel()
mas, err := s.addrsForDial(tctx, p1)
mas, _, err := s.addrsForDial(tctx, p1)
require.NoError(t, err)

require.Len(t, mas, 1)
Expand Down Expand Up @@ -241,7 +242,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel()
s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL)
_, err = s.addrsForDial(tctx, p1)
_, _, err = s.addrsForDial(tctx, p1)
require.NoError(t, err)

addrs1 := s.Peerstore().Addrs(pi1.ID)
Expand All @@ -253,7 +254,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
require.NoError(t, err)

s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL)
_, err = s.addrsForDial(tctx, p2)
_, _, err = s.addrsForDial(tctx, p2)
// This never resolves to a good address
require.Equal(t, ErrNoGoodAddresses, err)

Expand Down Expand Up @@ -315,7 +316,7 @@ func TestAddrsForDialFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
s.Peerstore().ClearAddrs(p1)
s.Peerstore().AddAddrs(p1, tc.input, peerstore.PermanentAddrTTL)
result, err := s.addrsForDial(ctx, p1)
result, _, err := s.addrsForDial(ctx, p1)
require.NoError(t, err)
sort.Slice(result, func(i, j int) bool { return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 })
sort.Slice(tc.output, func(i, j int) bool { return bytes.Compare(tc.output[i].Bytes(), tc.output[j].Bytes()) < 0 })
Expand Down Expand Up @@ -366,10 +367,10 @@ func TestBlackHoledAddrBlocked(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := s.DialPeer(ctx, p)
if conn != nil {
t.Fatalf("expected dial to be blocked")
}
if err != ErrNoGoodAddresses {
require.Nil(t, conn)
var de *DialError
if !errors.As(err, &de) {
t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err)
}
require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialRefusedBlackHole})
}