diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index aa6260314..d02a1a733 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -410,6 +410,8 @@ func TestStage1RaceRelays(t *testing.T) { p := r.RouteForAllUntilTxTun(myControl) _ = p + r.FlushAll() + myControl.Stop() theirControl.Stop() relayControl.Stop() diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3a49dcbb5..a2a57e13f 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -8,6 +8,7 @@ import ( "io" "net" "os" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" @@ -21,6 +22,7 @@ type TestTun struct { routeTree *cidr.Tree4 l *logrus.Logger + closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } @@ -50,6 +52,10 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int // These are unencrypted ip layer frames destined for another nebula node. // packets should exit the udp side, capture them with udpConn.Get func (t *TestTun) Send(packet []byte) { + if t.closed.Load() { + return + } + if t.l.Level >= logrus.DebugLevel { t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") } @@ -98,6 +104,10 @@ func (t *TestTun) Name() string { } func (t *TestTun) Write(b []byte) (n int, err error) { + if t.closed.Load() { + return 0, io.ErrClosedPipe + } + packet := make([]byte, len(b), len(b)) copy(packet, b) t.TxPackets <- packet @@ -105,7 +115,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) { } func (t *TestTun) Close() error { - close(t.rxPackets) + if t.closed.CompareAndSwap(false, true) { + close(t.rxPackets) + close(t.TxPackets) + } return nil } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f03a69cbe..55985f47f 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -5,7 +5,9 @@ package udp import ( "fmt" + "io" "net" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -42,7 +44,8 @@ type TesterConn struct { RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - l *logrus.Logger + closed atomic.Bool + l *logrus.Logger } func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { @@ -58,6 +61,10 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, er // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { + if u.closed.Load() { + return + } + h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) @@ -92,6 +99,10 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { + if u.closed.Load() { + return io.ErrClosedPipe + } + p := &Packet{ Data: make([]byte, len(b), len(b)), FromIp: make([]byte, 16), @@ -142,7 +153,9 @@ func (u *TesterConn) Rebind() error { } func (u *TesterConn) Close() error { - close(u.RxPackets) - close(u.TxPackets) + if u.closed.CompareAndSwap(false, true) { + close(u.RxPackets) + close(u.TxPackets) + } return nil }