diff --git a/p2p/net/swarm/dial_error.go b/p2p/net/swarm/dial_error.go index 711ee06072..e9b8731097 100644 --- a/p2p/net/swarm/dial_error.go +++ b/p2p/net/swarm/dial_error.go @@ -1,6 +1,7 @@ package swarm import ( + "errors" "fmt" "os" "strings" @@ -30,10 +31,7 @@ func (e *DialError) recordErr(addr ma.Multiaddr, err error) { e.Skipped++ return } - e.DialErrors = append(e.DialErrors, TransportError{ - Address: addr, - Cause: err, - }) + e.DialErrors = append(e.DialErrors, TransportError{Address: addr, Cause: err}) } func (e *DialError) Error() string { @@ -51,9 +49,22 @@ func (e *DialError) Error() string { return builder.String() } -// Unwrap implements https://godoc.org/golang.org/x/xerrors#Wrapper. -func (e *DialError) Unwrap() error { - return e.Cause +func (e *DialError) Unwrap() []error { + if e == nil || len(e.DialErrors) == 0 { + return nil + } + errs := make([]error, len(e.DialErrors)) + for i := 0; i < len(e.DialErrors); i++ { + errs[i] = e.DialErrors[i] + } + return errs +} + +func (e *DialError) Is(target error) bool { + if e == target { + return true + } + return e != nil && e.Cause != nil && errors.Is(e.Cause, target) } var _ error = (*DialError)(nil) @@ -64,8 +75,12 @@ type TransportError struct { Cause error } -func (e *TransportError) Error() string { +func (e TransportError) Error() string { return fmt.Sprintf("failed to dial %s: %s", e.Address, e.Cause) } -var _ error = (*TransportError)(nil) +func (e TransportError) Unwrap() error { + return e.Cause +} + +var _ error = TransportError{} diff --git a/p2p/net/swarm/dial_error_test.go b/p2p/net/swarm/dial_error_test.go new file mode 100644 index 0000000000..4051e7e0b3 --- /dev/null +++ b/p2p/net/swarm/dial_error_test.go @@ -0,0 +1,50 @@ +package swarm + +import ( + "net" + "os" + "testing" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func TestTransportError(t *testing.T) { + aa := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + te := TransportError{Address: aa, Cause: ErrDialBackoff} + require.ErrorIs(t, te, ErrDialBackoff, "TransportError should implement Unwrap") +} + +func TestDialError(t *testing.T) { + de := &DialError{Peer: "pid", Cause: ErrGaterDisallowedConnection} + require.ErrorIs(t, de, ErrGaterDisallowedConnection, + "DialError Unwrap should handle DialError.Cause") + require.ErrorIs(t, de, de, "DialError Unwrap should handle match to self") + + aa := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + ab := ma.StringCast("/ip6/1::1/udp/1234/quic-v1") + de = &DialError{ + Peer: "pid", + DialErrors: []TransportError{ + {Address: aa, Cause: ErrDialBackoff}, {Address: ab, Cause: ErrNoTransport}, + }, + } + require.ErrorIs(t, de, ErrDialBackoff, "DialError.Unwrap should traverse TransportErrors") + require.ErrorIs(t, de, ErrNoTransport, "DialError.Unwrap should traverse TransportErrors") + + de = &DialError{ + Peer: "pid", + DialErrors: []TransportError{{Address: ab, Cause: ErrNoTransport}, + // wrapped error 2 levels deep + {Address: aa, Cause: &net.OpError{ + Op: "write", + Net: "tcp", + Err: &os.SyscallError{ + Syscall: "connect", + Err: os.ErrPermission, + }, + }}, + }, + } + require.ErrorIs(t, de, os.ErrPermission, "DialError.Unwrap should traverse TransportErrors") +}