From 6281130baed050f4406fb080183eaf74ba954499 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Wed, 25 Sep 2024 11:55:26 -0400 Subject: [PATCH 01/32] feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS, and WSS transports --- p2p/transport/tcp/tcp.go | 32 ++- p2p/transport/tcp/tcp_test.go | 13 +- p2p/transport/tcpreuse/demultiplex.go | 240 ++++++++++++++++++ p2p/transport/tcpreuse/demultiplex_test.go | 50 ++++ p2p/transport/tcpreuse/dialer.go | 16 ++ p2p/transport/tcpreuse/listener.go | 250 +++++++++++++++++++ p2p/transport/{tcp => tcpreuse}/reuseport.go | 10 +- p2p/transport/websocket/addrs_test.go | 2 +- p2p/transport/websocket/listener.go | 34 ++- p2p/transport/websocket/websocket.go | 12 +- 10 files changed, 635 insertions(+), 24 deletions(-) create mode 100644 p2p/transport/tcpreuse/demultiplex.go create mode 100644 p2p/transport/tcpreuse/demultiplex_test.go create mode 100644 p2p/transport/tcpreuse/dialer.go create mode 100644 p2p/transport/tcpreuse/listener.go rename p2p/transport/{tcp => tcpreuse}/reuseport.go (81%) diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d52bb96019..66fe9b7631 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -33,6 +34,9 @@ type canKeepAlive interface { var _ canKeepAlive = &net.TCPConn{} +// Deprecated: Use tcpreuse.ReuseportIsAvailable +var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable + func tryKeepAlive(conn net.Conn, keepAlive bool) { keepAliveConn, ok := conn.(canKeepAlive) if !ok { @@ -113,6 +117,13 @@ func WithMetrics() Option { } } +func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { + return func(tr *TcpTransport) error { + tr.sharedTcp = mgr + return nil + } +} + // TcpTransport is the TCP transport. type TcpTransport struct { // Connection upgrader for upgrading insecure stream connections to @@ -122,6 +133,9 @@ type TcpTransport struct { disableReuseport bool // Explicitly disable reuseport. enableMetrics bool + // share and demultiplex TCP listeners across multiple transports + sharedTcp *tcpreuse.ConnMgr + // TCP connect timeout connectTimeout time.Duration @@ -168,6 +182,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co defer cancel() } + if t.sharedTcp != nil { + return t.sharedTcp.DialContext(ctx, raddr) + } + if t.UseReuseport() { return t.reuse.DialContext(ctx, raddr) } @@ -233,10 +251,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p // UseReuseport returns true if reuseport is enabled and available. func (t *TcpTransport) UseReuseport() bool { - return !t.disableReuseport && ReuseportIsAvailable() + return !t.disableReuseport && tcpreuse.ReuseportIsAvailable() } -func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) { if t.UseReuseport() { return t.reuse.Listen(laddr) } @@ -245,10 +263,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { // Listen listens on the given multiaddr. func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { - list, err := t.maListen(laddr) + var list manet.Listener + var err error + + if t.sharedTcp == nil { + list, err = t.unsharedMAListen(laddr) + } else { + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect) + } if err != nil { return nil, err } + if t.enableMetrics { list = newTracingListener(&tcpListener{list, 0}) } diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index a57a65e420..4c692fbf4c 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -14,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" @@ -41,9 +42,9 @@ func TestTcpTransport(t *testing.T) { zero := "/ip4/127.0.0.1/tcp/0" ttransport.SubtestTransport(t, ta, tb, zero, peerA) - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportWithMetrics(t *testing.T) { @@ -126,9 +127,9 @@ func TestTcpTransportCantDialDNS(t *testing.T) { t.Fatal("shouldn't be able to dial dns") } - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportCantListenUtp(t *testing.T) { @@ -143,9 +144,9 @@ func TestTcpTransportCantListenUtp(t *testing.T) { _, err = tpt.Listen(utpa) require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport") - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestDialWithUpdates(t *testing.T) { diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go new file mode 100644 index 0000000000..59e26a9aee --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -0,0 +1,240 @@ +package tcpreuse + +import ( + "bufio" + "errors" + "fmt" + "io" + "math" + "net" + "time" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type peekAble interface { + // Peek returns the next n bytes without advancing the reader. The bytes stop + // being valid at the next read call. If Peek returns fewer than n bytes, it + // also returns an error explaining why the read is short. The error is + // [ErrBufferFull] if n is larger than b's buffer size. + Peek(n int) ([]byte, error) +} + +var _ peekAble = (*bufio.Reader)(nil) + +type DemultiplexedConnType int + +const ( + Unknown DemultiplexedConnType = iota + MultistreamSelect + HTTP + TLS +) + +func (t DemultiplexedConnType) String() string { + switch t { + case MultistreamSelect: + return "MultistreamSelect" + case HTTP: + return "HTTP" + case TLS: + return "TLS" + default: + return fmt.Sprintf("Unknown(%d)", int(t)) + } +} + +func (t DemultiplexedConnType) IsKnown() bool { + return t >= 1 || t <= 3 +} + +func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { + if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + s, sc, err := ReadSampleFromConn(c) + if err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if IsMultistreamSelect(s) { + return MultistreamSelect, sc, nil + } + if IsTLS(s) { + return TLS, sc, nil + } + if IsHTTP(s) { + return HTTP, sc, nil + } + return Unknown, sc, nil +} + +// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged. +// If an error occurs it only return the error. +func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { + if peekAble, ok := c.(peekAble); ok { + b, err := peekAble.Peek(len(Sample{})) + switch { + case err == nil: + mac, err := manet.WrapNetConn(c) + if err != nil { + return Sample{}, nil, err + } + + return Sample(b), mac, nil + case errors.Is(err, bufio.ErrBufferFull): + // fallback to sampledConn + default: + return Sample{}, nil, err + } + } + + tcpConnLike, ok := c.(tcpConnInterface) + if !ok { + return Sample{}, nil, fmt.Errorf("expected tcp-like connection") + } + + laddr, err := manet.FromNetAddr(c.LocalAddr()) + if err != nil { + return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) + } + + raddr, err := manet.FromNetAddr(c.RemoteAddr()) + if err != nil { + return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) + } + + sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}} + _, err = io.ReadFull(c, sc.s[:]) + if err != nil { + return Sample{}, nil, err + } + + return sc.s, sc, nil +} + +// Try out best to mimic a TCPConn's functions +// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection +// If this is an issue here we can revisit the options. +type tcpConnInterface interface { + net.Conn + + CloseRead() error + CloseWrite() error + + SetLinger(sec int) error + SetKeepAlive(keepalive bool) error + SetKeepAlivePeriod(d time.Duration) error + SetNoDelay(noDelay bool) error + MultipathTCP() (bool, error) + + io.ReaderFrom + io.WriterTo +} + +type maEndpoints struct { + laddr ma.Multiaddr + raddr ma.Multiaddr +} + +// LocalMultiaddr returns the local address associated with +// this connection +func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr returns the remote address associated with +// this connection +func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { + return c.raddr +} + +type sampledConn struct { + tcpConnInterface + maEndpoints + + s Sample + readFromSample uint8 +} + +var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow +var _ io.ReaderFrom = (*sampledConn)(nil) +var _ io.WriterTo = (*sampledConn)(nil) + +func (sc *sampledConn) Read(b []byte) (int, error) { + if int(sc.readFromSample) != len(sc.s) { + red := copy(b, sc.s[sc.readFromSample:]) + sc.readFromSample += uint8(red) + return red, nil + } + + return sc.tcpConnInterface.Read(b) +} + +// forward optimizations +func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(sc.tcpConnInterface, r) +} + +// forward optimizations +func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { + if int(sc.readFromSample) != len(sc.s) { + b := sc.s[sc.readFromSample:] + written, err := w.Write(b) + if written < 0 || len(b) < written { + // buggy writer, harden against this + sc.readFromSample = uint8(len(sc.s)) + total = int64(len(sc.s)) + } else { + sc.readFromSample += uint8(written) + total += int64(written) + } + if err != nil { + return total, err + } + } + + written, err := io.Copy(w, sc.tcpConnInterface) + total += written + return total, err +} + +type Matcher interface { + Match(s Sample) bool +} + +// Sample might evolve over time. +type Sample [3]byte + +// Matchers are implemented here instead of in the transports so we can easily fuzz them together. + +func IsMultistreamSelect(s Sample) bool { + return string(s[:]) == "\x13/m" +} + +func IsHTTP(s Sample) bool { + switch string(s[:]) { + case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": + return true + default: + return false + } +} + +func IsTLS(s Sample) bool { + switch string(s[:]) { + case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04": + return true + default: + return false + } +} diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go new file mode 100644 index 0000000000..3d6e91f35a --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -0,0 +1,50 @@ +package tcpreuse + +import "testing" + +func FuzzClash(f *testing.F) { + // make untyped literals type correctly + add := func(a, b, c byte) { f.Add(a, b, c) } + + // multistream-select + add('\x13', '/', 'm') + // http + add('G', 'E', 'T') + add('H', 'E', 'A') + add('P', 'O', 'S') + add('P', 'U', 'T') + add('D', 'E', 'L') + add('C', 'O', 'N') + add('O', 'P', 'T') + add('T', 'R', 'A') + add('P', 'A', 'T') + // tls + add('\x16', '\x03', '\x01') + add('\x16', '\x03', '\x02') + add('\x16', '\x03', '\x03') + add('\x16', '\x03', '\x04') + + f.Fuzz(func(t *testing.T, a, b, c byte) { + s := Sample{a, b, c} + var total uint + + ms := IsMultistreamSelect(s) + if ms { + total++ + } + + http := IsHTTP(s) + if http { + total++ + } + + tls := IsTLS(s) + if tls { + total++ + } + + if total > 1 { + t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls) + } + }) +} diff --git a/p2p/transport/tcpreuse/dialer.go b/p2p/transport/tcpreuse/dialer.go new file mode 100644 index 0000000000..ad634583ed --- /dev/null +++ b/p2p/transport/tcpreuse/dialer.go @@ -0,0 +1,16 @@ +package tcpreuse + +import ( + "context" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// DialContext is like Dial but takes a context. +func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { + if t.useReuseport() { + return t.reuse.DialContext(ctx, raddr) + } + var d manet.Dialer + return d.DialContext(ctx, raddr) +} diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go new file mode 100644 index 0000000000..59aeed1f93 --- /dev/null +++ b/p2p/transport/tcpreuse/listener.go @@ -0,0 +1,250 @@ +package tcpreuse + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var log = logging.Logger("tcp-demultiplex") + +type ConnMgr struct { + disableReuseport bool + reuse reuseport.Transport + listeners map[string]*multiplexedListener + mx sync.Mutex +} + +func NewConnMgr(disableReuseport bool) *ConnMgr { + return &ConnMgr{ + disableReuseport: disableReuseport, + reuse: reuseport.Transport{}, + listeners: make(map[string]*multiplexedListener), + } +} + +func (t *ConnMgr) maListen(laddr ma.Multiaddr) (manet.Listener, error) { + if t.useReuseport() { + return t.reuse.Listen(laddr) + } else { + return manet.Listen(laddr) + } +} + +func (t *ConnMgr) useReuseport() bool { + return !t.disableReuseport && ReuseportIsAvailable() +} + +func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + t.mx.Lock() + defer t.mx.Unlock() + ml, ok := t.listeners[laddr.String()] + if ok { + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + return nil, err + } + return dl, nil + } + + l, err := t.maListen(laddr) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + cancelFunc := func() error { + cancel() + t.mx.Lock() + defer t.mx.Unlock() + delete(t.listeners, laddr.String()) + return l.Close() + } + ml = &multiplexedListener{ + Listener: l, + listeners: make(map[DemultiplexedConnType]*demultiplexedListener), + buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? + ctx: ctx, + closeFn: cancelFunc, + } + + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + cerr := ml.Close() + return nil, errors.Join(err, cerr) + } + + go func() { + err = ml.Run() + if err != nil { + log.Debugf("Error running multiplexed listener: %s", err.Error()) + } + }() + + t.listeners[laddr.String()] = ml + + return dl, nil +} + +var _ manet.Listener = &demultiplexedListener{} + +type multiplexedListener struct { + manet.Listener + listeners map[DemultiplexedConnType]*demultiplexedListener + mx sync.Mutex + listenerCounter int + buffer chan manet.Conn + + ctx context.Context + closeFn func() error +} + +func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + m.mx.Lock() + defer m.mx.Unlock() + l, ok := m.listeners[connType] + if ok { + return l, nil + } + + ctx, cancel := context.WithCancel(m.ctx) + closeFn := func() error { + cancel() + m.mx.Lock() + defer m.mx.Unlock() + m.listenerCounter-- + if m.listenerCounter == 0 { + return m.Close() + } + return nil + } + + l = &demultiplexedListener{ + buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? + inner: m.Listener, + ctx: ctx, + closeFn: closeFn, + } + + m.listeners[connType] = l + m.listenerCounter++ + + return l, nil +} + +func (m *multiplexedListener) Run() error { + const numWorkers = 16 + for i := 0; i < numWorkers; i++ { + go func() { + m.background() + }() + } + + for { + c, err := m.Listener.Accept() + if err != nil { + return err + } + + select { + case m.buffer <- c: + case <-m.ctx.Done(): + return transport.ErrListenerClosed + } + } +} + +func (m *multiplexedListener) background() { + // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? + // Drop connection because the buffer is full + for { + select { + case c := <-m.buffer: + t, sampleC, err := ConnTypeFromConn(c) + if err != nil { + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error demultiplexing connection: %s", err.Error()) + continue + } + + demux, ok := m.listeners[t] + if !ok { + closeErr := c.Close() + if closeErr != nil { + log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("no registered listener for demultiplex connection %s", t) + } + continue + } + + select { + case demux.buffer <- sampleC: + case <-m.ctx.Done(): + return + default: + closeErr := c.Close() + if closeErr != nil { + log.Debugf("dropped connection due to full buffer of awaiting connections of type %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("dropped connection due to full buffer of awaiting connections of type %s", t) + } + continue + } + case <-m.ctx.Done(): + return + } + } +} + +func (m *multiplexedListener) Close() error { + cerr := m.closeFn() + lerr := m.Listener.Close() + return errors.Join(lerr, cerr) +} + +type demultiplexedListener struct { + buffer chan manet.Conn + inner manet.Listener + ctx context.Context + closeFn func() error +} + +func (m *demultiplexedListener) Accept() (manet.Conn, error) { + select { + case c := <-m.buffer: + return c, nil + case <-m.ctx.Done(): + return nil, transport.ErrListenerClosed + } +} + +func (m *demultiplexedListener) Close() error { + return m.closeFn() +} + +func (m *demultiplexedListener) Multiaddr() ma.Multiaddr { + // TODO: do we need to add a suffix for the rest of the transport? + return m.inner.Multiaddr() +} + +func (m *demultiplexedListener) Addr() net.Addr { + return m.inner.Addr() +} diff --git a/p2p/transport/tcp/reuseport.go b/p2p/transport/tcpreuse/reuseport.go similarity index 81% rename from p2p/transport/tcp/reuseport.go rename to p2p/transport/tcpreuse/reuseport.go index ba09304622..a2529c0bda 100644 --- a/p2p/transport/tcp/reuseport.go +++ b/p2p/transport/tcpreuse/reuseport.go @@ -1,4 +1,4 @@ -package tcp +package tcpreuse import ( "os" @@ -11,13 +11,13 @@ import ( // It default to true. const envReuseport = "LIBP2P_TCP_REUSEPORT" -// envReuseportVal stores the value of envReuseport. defaults to true. -var envReuseportVal = true +// EnvReuseportVal stores the value of envReuseport. defaults to true. +var EnvReuseportVal = true func init() { v := strings.ToLower(os.Getenv(envReuseport)) if v == "false" || v == "f" || v == "0" { - envReuseportVal = false + EnvReuseportVal = false log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v) } } @@ -31,5 +31,5 @@ func init() { // If this becomes a sought after feature, we could add this to the config. // In the end, reuseport is a stop-gap. func ReuseportIsAvailable() bool { - return envReuseportVal && reuseport.Available() + return EnvReuseportVal && reuseport.Available() } diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..50a8b9e823 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 8071ddb814..1bf4f2ee47 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -12,6 +12,7 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -50,7 +51,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -60,19 +61,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err - } - nl, err := net.Listen(lnet, lnaddr) - if err != nil { - return nil, err + var nl net.Listener + + if sharedTcp == nil { + lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) + if err != nil { + return nil, err + } + nl, err = net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + } else { + var connType tcpreuse.DemultiplexedConnType + if parsed.isWSS { + connType = tcpreuse.TLS + } else { + connType = tcpreuse.HTTP + } + mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) + if err != nil { + return nil, err + } + nl = manet.NetListener(mal) } laddr, err := manet.FromNetAddr(nl.Addr()) if err != nil { return nil, err } + first, _ := ma.SplitFirst(a) // Don't resolve dns addresses. // We want to be able to announce domain names, so the peer can validate the TLS certificate. diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 0f07617dc7..304e4d7ba6 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -80,6 +81,13 @@ func WithTLSConfig(conf *tls.Config) Option { } } +func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { + return func(t *WebsocketTransport) error { + t.sharedTcp = mgr + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader @@ -87,6 +95,8 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + sharedTcp *tcpreuse.ConnMgr } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -233,7 +243,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf) + l, err := newListener(a, tlsConf, t.sharedTcp) if err != nil { return nil, err } From 2ae20f4717f3eaa62cda3db6fd8aa7f759b32c31 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 2 Oct 2024 20:12:27 +0530 Subject: [PATCH 02/32] add tests for listener --- p2p/transport/tcp/tcp.go | 4 +- p2p/transport/tcpreuse/demultiplex.go | 48 +-- p2p/transport/tcpreuse/listener.go | 163 +++++---- p2p/transport/tcpreuse/listener_test.go | 430 ++++++++++++++++++++++++ p2p/transport/websocket/listener.go | 4 +- 5 files changed, 553 insertions(+), 96 deletions(-) create mode 100644 p2p/transport/tcpreuse/listener_test.go diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 66fe9b7631..5883e43f6a 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -148,7 +148,7 @@ var _ transport.Transport = &TcpTransport{} var _ transport.DialUpdater = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners -// created. It represents an entire TCP stack (though it might not necessarily be). +// created. func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} @@ -269,7 +269,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { if t.sharedTcp == nil { list, err = t.unsharedMAListen(laddr) } else { - list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect) + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) } if err != nil { return nil, err diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 59e26a9aee..2036c91437 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -23,22 +23,24 @@ type peekAble interface { var _ peekAble = (*bufio.Reader)(nil) +// TODO: We can unexport this type and rely completely on the multiaddr passed in to +// DemultiplexedListen. type DemultiplexedConnType int const ( - Unknown DemultiplexedConnType = iota - MultistreamSelect - HTTP - TLS + DemultiplexedConnType_Unknown DemultiplexedConnType = iota + DemultiplexedConnType_MultistreamSelect + DemultiplexedConnType_HTTP + DemultiplexedConnType_TLS ) func (t DemultiplexedConnType) String() string { switch t { - case MultistreamSelect: + case DemultiplexedConnType_MultistreamSelect: return "MultistreamSelect" - case HTTP: + case DemultiplexedConnType_HTTP: return "HTTP" - case TLS: + case DemultiplexedConnType_TLS: return "TLS" default: return fmt.Sprintf("Unknown(%d)", int(t)) @@ -49,7 +51,7 @@ func (t DemultiplexedConnType) IsKnown() bool { return t >= 1 || t <= 3 } -func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { +func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) @@ -67,20 +69,24 @@ func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { } if IsMultistreamSelect(s) { - return MultistreamSelect, sc, nil + return DemultiplexedConnType_MultistreamSelect, sc, nil } if IsTLS(s) { - return TLS, sc, nil + return DemultiplexedConnType_TLS, sc, nil } if IsHTTP(s) { - return HTTP, sc, nil + return DemultiplexedConnType_HTTP, sc, nil } - return Unknown, sc, nil + return DemultiplexedConnType_Unknown, sc, nil } -// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged. -// If an error occurs it only return the error. +// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. +// If an error occurs it only returns the error. func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { + // TODO: Should we remove this? This is only implemented by bufio.Reader. + // This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped + // ReadWriteCloser from multistream which does use a buffered reader underneath. + // For our present purpose, we have a net.Conn and no net.Conn implementation offers peeking. if peekAble, ok := c.(peekAble); ok { b, err := peekAble.Peek(len(Sample{})) switch { @@ -92,6 +98,7 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { return Sample(b), mac, nil case errors.Is(err, bufio.ErrBufferFull): + // We can only peek < len(Sample{}) data. // fallback to sampledConn default: return Sample{}, nil, err @@ -118,13 +125,12 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { if err != nil { return Sample{}, nil, err } - return sc.s, sc, nil } -// Try out best to mimic a TCPConn's functions -// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection -// If this is an issue here we can revisit the options. +// tcpConnInterface is the interface for TCPConn's functions +// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection. +// TODO: allow SyscallConn? Disallowing it breaks metrics tracking in TCP Transport. type tcpConnInterface interface { net.Conn @@ -180,12 +186,12 @@ func (sc *sampledConn) Read(b []byte) (int, error) { return sc.tcpConnInterface.Read(b) } -// forward optimizations +// TODO: Do we need these? + func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) { return io.Copy(sc.tcpConnInterface, r) } -// forward optimizations func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { if int(sc.readFromSample) != len(sc.s) { b := sc.s[sc.readFromSample:] @@ -212,7 +218,7 @@ type Matcher interface { Match(s Sample) bool } -// Sample might evolve over time. +// Sample is the byte sequence we use to demultiplex. type Sample [3]byte // Matchers are implemented here instead of in the transports so we can easily fuzz them together. diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 59aeed1f93..73286d5006 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -14,8 +14,11 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) +const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel. + var log = logging.Logger("tcp-demultiplex") +// ConnMgr enables you to share the same listen address between TCP and WebSocket transports. type ConnMgr struct { disableReuseport bool reuse reuseport.Transport @@ -31,11 +34,11 @@ func NewConnMgr(disableReuseport bool) *ConnMgr { } } -func (t *ConnMgr) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { if t.useReuseport() { - return t.reuse.Listen(laddr) + return t.reuse.Listen(listenAddr) } else { - return manet.Listen(laddr) + return manet.Listen(listenAddr) } } @@ -43,10 +46,34 @@ func (t *ConnMgr) useReuseport() bool { return !t.disableReuseport && ReuseportIsAvailable() } +func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { + haveTCP := false + addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool { + if haveTCP { + return true + } + if c.Protocol().Code == ma.P_TCP { + haveTCP = true + } + return false + }) + if !haveTCP { + return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr) + } + return addr, nil +} + +// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections +// accepted from returned listeners need to be upgraded with a `transport.Upgrader`. +// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port. func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { if !connType.IsKnown() { return nil, fmt.Errorf("unknown connection type: %s", connType) } + laddr, err := getTCPAddr(laddr) + if err != nil { + return nil, err + } t.mx.Lock() defer t.mx.Unlock() @@ -75,7 +102,6 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed ml = &multiplexedListener{ Listener: l, listeners: make(map[DemultiplexedConnType]*demultiplexedListener), - buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? ctx: ctx, closeFn: cancelFunc, } @@ -86,15 +112,11 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed return nil, errors.Join(err, cerr) } - go func() { - err = ml.Run() - if err != nil { - log.Debugf("Error running multiplexed listener: %s", err.Error()) - } - }() - t.listeners[laddr.String()] = ml + ml.wg.Add(1) + go ml.run() + return dl, nil } @@ -102,13 +124,12 @@ var _ manet.Listener = &demultiplexedListener{} type multiplexedListener struct { manet.Listener - listeners map[DemultiplexedConnType]*demultiplexedListener - mx sync.Mutex - listenerCounter int - buffer chan manet.Conn + listeners map[DemultiplexedConnType]*demultiplexedListener + mx sync.RWMutex ctx context.Context closeFn func() error + wg sync.WaitGroup } func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { @@ -124,67 +145,51 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType } ctx, cancel := context.WithCancel(m.ctx) - closeFn := func() error { - cancel() - m.mx.Lock() - defer m.mx.Unlock() - m.listenerCounter-- - if m.listenerCounter == 0 { - return m.Close() - } - return nil - } - l = &demultiplexedListener{ - buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? - inner: m.Listener, - ctx: ctx, - closeFn: closeFn, + buffer: make(chan manet.Conn), + inner: m.Listener, + ctx: ctx, + cancelFunc: cancel, + closeFn: func() error { m.removeDemultiplexedListener(connType); return nil }, } m.listeners[connType] = l - m.listenerCounter++ return l, nil } -func (m *multiplexedListener) Run() error { - const numWorkers = 16 - for i := 0; i < numWorkers; i++ { - go func() { - m.background() - }() - } - +func (m *multiplexedListener) run() error { + defer m.Close() + defer m.wg.Done() + acceptQueue := make(chan struct{}, acceptQueueSize) for { c, err := m.Listener.Accept() if err != nil { return err } - select { - case m.buffer <- c: + case acceptQueue <- struct{}{}: case <-m.ctx.Done(): - return transport.ErrListenerClosed + c.Close() + log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) } - } -} -func (m *multiplexedListener) background() { - // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? - // Drop connection because the buffer is full - for { - select { - case c := <-m.buffer: - t, sampleC, err := ConnTypeFromConn(c) + m.wg.Add(1) + go func() { + defer func() { <-acceptQueue }() + defer m.wg.Done() + // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? + // Drop connection because the buffer is full + t, sampleC, err := getDemultiplexedConn(c) if err != nil { closeErr := c.Close() err = errors.Join(err, closeErr) log.Debugf("error demultiplexing connection: %s", err.Error()) - continue + return } - + m.mx.RLock() demux, ok := m.listeners[t] + m.mx.RUnlock() if !ok { closeErr := c.Close() if closeErr != nil { @@ -192,39 +197,55 @@ func (m *multiplexedListener) background() { } else { log.Debugf("no registered listener for demultiplex connection %s", t) } - continue + return } select { case demux.buffer <- sampleC: case <-m.ctx.Done(): + sampleC.Close() return - default: - closeErr := c.Close() - if closeErr != nil { - log.Debugf("dropped connection due to full buffer of awaiting connections of type %s. Error closing the connection %s", t, closeErr.Error()) - } else { - log.Debugf("dropped connection due to full buffer of awaiting connections of type %s", t) - } - continue } - case <-m.ctx.Done(): - return - } + }() } } func (m *multiplexedListener) Close() error { - cerr := m.closeFn() + m.mx.Lock() + for _, l := range m.listeners { + l.cancelFunc() + } + err := m.closeListener() + m.mx.Unlock() + m.wg.Wait() + return err +} + +func (m *multiplexedListener) closeListener() error { lerr := m.Listener.Close() + cerr := m.closeFn() return errors.Join(lerr, cerr) } +func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnType) { + m.mx.Lock() + defer m.mx.Unlock() + + delete(m.listeners, c) + if len(m.listeners) == 0 { + m.closeListener() + m.mx.Unlock() + m.wg.Wait() + m.mx.Lock() + } +} + type demultiplexedListener struct { - buffer chan manet.Conn - inner manet.Listener - ctx context.Context - closeFn func() error + buffer chan manet.Conn + inner manet.Listener + ctx context.Context + cancelFunc context.CancelFunc + closeFn func() error } func (m *demultiplexedListener) Accept() (manet.Conn, error) { @@ -237,11 +258,11 @@ func (m *demultiplexedListener) Accept() (manet.Conn, error) { } func (m *demultiplexedListener) Close() error { + m.cancelFunc() return m.closeFn() } func (m *demultiplexedListener) Multiaddr() ma.Multiaddr { - // TODO: do we need to add a suffix for the rest of the transport? return m.inner.Multiaddr() } diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go new file mode 100644 index 0000000000..f9d1fe589c --- /dev/null +++ b/p2p/transport/tcpreuse/listener_test.go @@ -0,0 +1,430 @@ +package tcpreuse + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multistream" + "github.com/stretchr/testify/require" +) + +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: &big.Int{}, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + return tlsConfig +} + +type wsHandler struct{ conns chan *websocket.Conn } + +func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + u := websocket.Upgrader{} + c, _ := u.Upgrade(w, r, http.Header{}) + wh.conns <- c +} + +func TestListenerSingle(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 128 + for _, disableReuseport := range []bool{true, false} { + t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(disableReuseport) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("hello-multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c, err := l.Accept() + require.NoError(t, err) + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + buf := make([]byte, 30) + n, err := cc.Read(buf) + require.NoError(t, err) + require.Equal(t, "hello-multistream", string(buf[:n])) + c.Close() + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(disableReuseport) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(l), wh) + }() + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "hello", string(buf)) + c.Close() + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(disableReuseport) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer l.Close() + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(l), "", "") + }() + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "hello", string(buf)) + c.Close() + }() + } + wg.Wait() + }) + } +} + +func TestListenerMultiplexed(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 128 + for _, disableReuseport := range []bool{true, false} { + cm := NewConnMgr(disableReuseport) + msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + defer msl.Close() + + wsl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + defer wsl.Close() + require.Equal(t, wsl.Multiaddr(), msl.Multiaddr()) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(wsl), wh) + }() + + wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer wssl.Close() + require.Equal(t, wssl.Multiaddr(), wsl.Multiaddr()) + whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(wssl), "", "") + }() + + // multistream connections + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, msl.Addr().Network(), msl.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // ws connections + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // wss connections + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket-tls")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c, err := msl.Accept() + require.NoError(t, err) + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + buf := make([]byte, 20) + n, err := cc.Read(buf) + require.NoError(t, err) + require.Equal(t, "multistream", string(buf[:n])) + cc.Close() + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "websocket", string(buf)) + c.Close() + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-whs.conns + wg.Add(1) + go func() { + defer wg.Done() + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "websocket-tls", string(buf)) + c.Close() + }() + } + }() + wg.Wait() + } +} + +func TestListenerClose(t *testing.T) { + + testClose := func(listenAddr ma.Multiaddr) { + // listen on port 0 + cm := NewConnMgr(true) + ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + require.Equal(t, mll, ml) + + wl.Close() + ml.Close() + + ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + + require.NotEqual(t, ml.Multiaddr(), mll.Multiaddr()) + require.NotEqual(t, mll, ml) + ml.Close() + + // Now listen on the specific port previously used + listenAddr = ml.Multiaddr() + ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + mll, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + require.Equal(t, mll, ml) + + wl.Close() + ml.Close() + + ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + + require.Equal(t, ml.Multiaddr(), mll.Multiaddr()) + require.NotEqual(t, mll, ml) + ml.Close() + } + listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")} + for _, listenAddr := range listenAddrs { + testClose(listenAddr) + } +} diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 1bf4f2ee47..40b290e212 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -75,9 +75,9 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg } else { var connType tcpreuse.DemultiplexedConnType if parsed.isWSS { - connType = tcpreuse.TLS + connType = tcpreuse.DemultiplexedConnType_TLS } else { - connType = tcpreuse.HTTP + connType = tcpreuse.DemultiplexedConnType_HTTP } mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) if err != nil { From 05ee445f47a468b44d9a5688453b1274935c55c1 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 6 Oct 2024 21:30:18 +0530 Subject: [PATCH 03/32] create conn scope early to prevent DoS attacks --- p2p/net/upgrader/listener.go | 38 +++++++++++------- p2p/transport/tcpreuse/demultiplex.go | 26 +++++++++---- p2p/transport/tcpreuse/listener.go | 52 +++++++++++++++++++++---- p2p/transport/tcpreuse/listener_test.go | 10 ++--- p2p/transport/websocket/conn.go | 11 ++++++ 5 files changed, 102 insertions(+), 35 deletions(-) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 8af2791b36..0530bde292 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -84,23 +84,33 @@ func (l *listener) handleIncoming() { } catcher.Reset() - // gate the connection if applicable - if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - if err := maconn.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue + var connScope network.ConnManagementScope + if sc, ok := maconn.(interface { + Scope() network.ConnManagementScope + }); ok { + connScope = sc.Scope() } - connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := maconn.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + if connScope != nil { + // gate the connection if applicable + if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + if err := maconn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + var err error + connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := maconn.Close(); err != nil { + log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + } + continue } - continue } // The go routine below calls Release when the context is diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 2036c91437..342e7de0b3 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -9,6 +9,7 @@ import ( "net" "time" + "github.com/libp2p/go-libp2p/core/network" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -51,13 +52,13 @@ func (t DemultiplexedConnType) IsKnown() bool { return t >= 1 || t <= 3 } -func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { +func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - s, sc, err := ReadSampleFromConn(c) + s, sc, err := readSampleFromConn(c, scope) if err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) @@ -80,9 +81,9 @@ func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) return DemultiplexedConnType_Unknown, sc, nil } -// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. +// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. // If an error occurs it only returns the error. -func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { +func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) { // TODO: Should we remove this? This is only implemented by bufio.Reader. // This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped // ReadWriteCloser from multistream which does use a buffered reader underneath. @@ -120,7 +121,11 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) } - sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}} + sc := &sampledConn{ + tcpConnInterface: tcpConnLike, + maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}, + scope: scope, + } _, err = io.ReadFull(c, sc.s[:]) if err != nil { return Sample{}, nil, err @@ -167,7 +172,7 @@ func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { type sampledConn struct { tcpConnInterface maEndpoints - + scope network.ConnManagementScope s Sample readFromSample uint8 } @@ -214,8 +219,13 @@ func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { return total, err } -type Matcher interface { - Match(s Sample) bool +func (sc *sampledConn) Scope() network.ConnManagementScope { + return sc.scope +} + +func (sc *sampledConn) Close() error { + sc.scope.Done() + return sc.tcpConnInterface.Close() } // Sample is the byte sequence we use to demultiplex. diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 73286d5006..4a5bfa119b 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -8,6 +8,8 @@ import ( "sync" logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" ma "github.com/multiformats/go-multiaddr" @@ -22,14 +24,22 @@ var log = logging.Logger("tcp-demultiplex") type ConnMgr struct { disableReuseport bool reuse reuseport.Transport - listeners map[string]*multiplexedListener - mx sync.Mutex + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + + mx sync.Mutex + listeners map[string]*multiplexedListener } -func NewConnMgr(disableReuseport bool) *ConnMgr { +func NewConnMgr(disableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } return &ConnMgr{ disableReuseport: disableReuseport, reuse: reuseport.Transport{}, + connGater: gater, + rcmgr: rcmgr, listeners: make(map[string]*multiplexedListener), } } @@ -104,6 +114,8 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed listeners: make(map[DemultiplexedConnType]*demultiplexedListener), ctx: ctx, closeFn: cancelFunc, + connGater: t.connGater, + rcmgr: t.rcmgr, } dl, err := ml.DemultiplexedListen(connType) @@ -127,9 +139,11 @@ type multiplexedListener struct { listeners map[DemultiplexedConnType]*demultiplexedListener mx sync.RWMutex - ctx context.Context - closeFn func() error - wg sync.WaitGroup + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + ctx context.Context + closeFn func() error + wg sync.WaitGroup } func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { @@ -167,6 +181,26 @@ func (m *multiplexedListener) run() error { if err != nil { return err } + + // gate the connection if applicable + if m.connGater != nil && !m.connGater.InterceptAccept(c) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + c.LocalMultiaddr(), c.RemoteMultiaddr()) + if err := c.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := c.Close(); err != nil { + log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + } + continue + } + select { case acceptQueue <- struct{}{}: case <-m.ctx.Done(): @@ -180,18 +214,20 @@ func (m *multiplexedListener) run() error { defer m.wg.Done() // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? // Drop connection because the buffer is full - t, sampleC, err := getDemultiplexedConn(c) + t, sampleC, err := getDemultiplexedConn(c, connScope) if err != nil { + connScope.Done() closeErr := c.Close() err = errors.Join(err, closeErr) log.Debugf("error demultiplexing connection: %s", err.Error()) return } + m.mx.RLock() demux, ok := m.listeners[t] m.mx.RUnlock() if !ok { - closeErr := c.Close() + closeErr := sampleC.Close() if closeErr != nil { log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) } else { diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index f9d1fe589c..b61ffce09d 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -65,7 +65,7 @@ func TestListenerSingle(t *testing.T) { const N = 128 for _, disableReuseport := range []bool{true, false} { t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) go func() { @@ -112,7 +112,7 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) require.NoError(t, err) wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} @@ -158,7 +158,7 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) require.NoError(t, err) defer l.Close() @@ -211,7 +211,7 @@ func TestListenerMultiplexed(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 128 for _, disableReuseport := range []bool{true, false} { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) defer msl.Close() @@ -370,7 +370,7 @@ func TestListenerClose(t *testing.T) { testClose := func(listenAddr ma.Multiaddr) { // listen on port 0 - cm := NewConnMgr(true) + cm := NewConnMgr(true, nil, nil) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..19d4e46ec5 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -99,9 +99,20 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *Conn) Scope() network.ConnManagementScope { + nc := c.NetConn() + if sc, ok := nc.(interface { + Scope() network.ConnManagementScope + }); ok { + return sc.Scope() + } + return nil +} + // Close closes the connection. Only the first call to Close will receive the // close error, subsequent and concurrent calls will return nil. // This method is thread-safe. +// TODO: Fix this ^ func (c *Conn) Close() error { var err error c.closeOnce.Do(func() { From 3457a7bc680713afc620f9392b41c1cbc0d1a5b0 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 6 Oct 2024 23:09:32 +0530 Subject: [PATCH 04/32] add fx option, move rcmgr and upgrader to sharedtcp --- config/config.go | 9 ++++++ libp2p_test.go | 15 +++++++++- options.go | 7 +++++ p2p/net/swarm/dial_worker_test.go | 2 +- p2p/net/swarm/swarm_addr_test.go | 2 +- p2p/net/swarm/swarm_dial_test.go | 8 +++--- p2p/net/swarm/testing/testing.go | 2 +- p2p/net/upgrader/listener.go | 2 +- p2p/protocol/circuitv2/relay/relay_test.go | 2 +- p2p/test/transport/gating_test.go | 25 +++++++++++++++- p2p/test/transport/transport_test.go | 32 +++++++++++++++++++++ p2p/transport/tcp/tcp.go | 3 +- p2p/transport/tcp/tcp_test.go | 20 ++++++------- p2p/transport/tcpreuse/listener.go | 1 - p2p/transport/websocket/conn.go | 33 ++++++++++++++++++++++ p2p/transport/websocket/listener.go | 18 ++++++------ p2p/transport/websocket/websocket.go | 3 +- p2p/transport/websocket/websocket_test.go | 28 +++++++++--------- 18 files changed, 165 insertions(+), 47 deletions(-) diff --git a/config/config.go b/config/config.go index fb5a2ab1b1..900c06bc30 100644 --- a/config/config.go +++ b/config/config.go @@ -38,6 +38,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/prometheus/client_golang/prometheus" @@ -145,6 +146,8 @@ type Config struct { CustomIPv6BlackHoleSuccessCounter bool UserFxOptions []fx.Option + + ShareTCPListener bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -289,6 +292,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), + fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { + if !cfg.ShareTCPListener { + return nil + } + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) + }), fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { quicAddrPorts := map[string]struct{}{} diff --git a/libp2p_test.go b/libp2p_test.go index b290227fc1..0aa261d23c 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -59,7 +59,7 @@ func TestTransportConstructor(t *testing.T) { _ connmgr.ConnectionGater, upgrader transport.Upgrader, ) transport.Transport { - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) require.NoError(t, err) return tpt } @@ -751,3 +751,16 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { }}, } } + +func TestSharedTCPAddr(t *testing.T) { + h, err := New( + ShareTCPListener(), + Transport(tcp.NewTCPTransport), + Transport(websocket.New), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"), + ) + require.NoError(t, err) + fmt.Println(h.Addrs()) + h.Close() +} diff --git a/options.go b/options.go index 4fbf8eb2ac..7c94ed7892 100644 --- a/options.go +++ b/options.go @@ -643,3 +643,10 @@ func WithFxOption(opts ...fx.Option) Option { return nil } } + +func ShareTCPListener() Option { + return func(cfg *Config) error { + cfg.ShareTCPListener = true + return nil + } +} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index ed4f00ff58..d264fd1230 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { upgrader := makeUpgrader(t, s) var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index 435866e920..43e76716e5 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) { s, err := swarm.NewSwarm("local", nil, eventbus.NewBus()) require.NoError(t, err) - tcpTr, err := tcp.NewTCPTransport(nil, nil) + tcpTr, err := tcp.NewTCPTransport(nil, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(tcpTr)) reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 0ef43cf62e..add6f5cbba 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - tpt, err := websocket.New(nil, &network.NullResourceManager{}) + tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver})) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) { require.NoError(t, err) defer s.Close() - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { }) // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { err = s.AddTransport(wtTpt) require.NoError(t, err) - wsTpt, err := websocket.New(nil, &network.NullResourceManager{}) + wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(wsTpt) require.NoError(t, err) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 2bbe8b27a5..773314a1b8 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm { if cfg.disableReuseport { tcpOpts = append(tcpOpts, tcp.DisableReuseport()) } - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 0530bde292..65da2bec6c 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -91,7 +91,7 @@ func (l *listener) handleIncoming() { connScope = sc.Scope() } - if connScope != nil { + if connScope == nil { // gate the connection if applicable if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { log.Debugf("gater blocked incoming connection on local addr %s from %s", diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index e5d32b0c96..f6b63e32de 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u upgrader := swarmt.GenUpgrader(t, netw, nil) upgraders = append(upgraders, upgrader) - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..99ce67b521 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -2,6 +2,8 @@ package transport_integration import ( "context" + "encoding/binary" + "net/netip" "strings" "testing" "time" @@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { return addr } +func addrPort(addr ma.Multiaddr) netip.AddrPort { + a := netip.Addr{} + p := uint16(0) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + a, _ = netip.AddrFromSlice(c.RawValue()) + return false + } + if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP { + p = binary.BigEndian.Uint16(c.RawValue()) + return true + } + return false + }) + return netip.AddrPortFrom(a, p) +} + func TestInterceptPeerDial(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) { // remove the certhash component from WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() + } else if strings.Contains(tc.Name, "WebSocket-Shared") { + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) + }) } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr()) }) } diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..60f8ca0c06 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "TCP-Shared / TLS / Yamux", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, + { + Name: "WebSocket-Shared", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 5883e43f6a..1b145c2b45 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -149,7 +149,7 @@ var _ transport.DialUpdater = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners // created. -func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { +func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -157,6 +157,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, upgrader: upgrader, connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option rcmgr: rcmgr, + sharedTcp: sharedTCP, } for _, o := range opts { if err := o(tr); err != nil { diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 4c692fbf4c..1f939d92be 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -32,11 +32,11 @@ func TestTcpTransport(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -53,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil, WithMetrics()) + ta, err := NewTCPTransport(ua, nil, nil, WithMetrics()) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil, WithMetrics()) + tb, err := NewTCPTransport(ub, nil, nil, WithMetrics()) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -73,7 +73,7 @@ func TestResourceManager(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -82,7 +82,7 @@ func TestResourceManager(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tb, err := NewTCPTransport(ub, rcmgr) + tb, err := NewTCPTransport(ub, rcmgr, nil) require.NoError(t, err) t.Run("success", func(t *testing.T) { @@ -120,7 +120,7 @@ func TestTcpTransportCantDialDNS(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) if tpt.CanDial(dnsa) { @@ -138,7 +138,7 @@ func TestTcpTransportCantListenUtp(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) _, err = tpt.Listen(utpa) @@ -155,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -163,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) updCh := make(chan transport.DialUpdate, 1) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 4a5bfa119b..e0bfe8eef2 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -191,7 +191,6 @@ func (m *multiplexedListener) run() error { } continue } - connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) if err != nil { log.Debugw("resource manager blocked accept of new connection", "error", err) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 19d4e46ec5..df97189d90 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -8,6 +8,8 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ws "github.com/gorilla/websocket" ) @@ -23,21 +25,52 @@ type Conn struct { DefaultMessageType int reader io.Reader closeOnce sync.Once + laddr ma.Multiaddr + raddr ma.Multiaddr readLock, writeLock sync.Mutex } var _ net.Conn = (*Conn)(nil) +var _ manet.Conn = (*Conn)(nil) // NewConn creates a Conn given a regular gorilla/websocket Conn. +// +// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. func NewConn(raw *ws.Conn, secure bool) *Conn { + lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) + laddr, err := manet.FromNetAddr(lna) + if err != nil { + log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr()) + return nil + } + + rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + raddr, err := manet.FromNetAddr(rna) + if err != nil { + log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr()) + return nil + } + return &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, } } +// LocalMultiaddr implements manet.Conn. +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr implements manet.Conn. +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddr +} + func (c *Conn) Read(b []byte) (int, error) { c.readLock.Lock() defer c.readLock.Unlock() diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 40b290e212..dd399aa079 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -4,11 +4,12 @@ import ( "crypto/tls" "errors" "fmt" - "go.uber.org/zap" "net" "net/http" "sync" + "go.uber.org/zap" + logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/transport" @@ -129,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { // The upgrader writes a response for us. return } - + nc := NewConn(c, l.isWss) + if nc == nil { + c.Close() + w.WriteHeader(500) + return + } select { case l.incoming <- NewConn(c, l.isWss): case <-l.closed: @@ -144,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 304e4d7ba6..8388a7c1e3 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -101,7 +101,7 @@ type WebsocketTransport struct { var _ transport.Transport = (*WebsocketTransport)(nil) -func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { +func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -109,6 +109,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (* upgrader: u, rcmgr: rcmgr, tlsClientConf: &tls.Config{}, + sharedTcp: sharedTCP, } for _, opt := range opts { if err := opt(t); err != nil { diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..9ca03775a2 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID } id, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } @@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) { require.Contains(t, serverMA.String(), "tls") _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) { func TestWebsocketTransport(t *testing.T) { peerA, ua := newUpgrader(t) - ta, err := New(ua, nil) + ta, err := New(ua, nil, nil) if err != nil { t.Fatal(err) } _, ub := newUpgrader(t) - tb, err := New(ub, nil) + tb, err := New(ub, nil, nil) if err != nil { t.Fatal(err) } @@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSConfig(tlsConf)) } server, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) l, err := tpt.Listen(laddr) require.NoError(t, err) @@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) } _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) @@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) { func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss") _, err = tpt.Listen(addr) @@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { func TestWebsocketListenSecureAndInsecure(t *testing.T) { serverID, serverUpgrader := newUpgrader(t) - server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t))) + server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t))) require.NoError(t, err) lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) @@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("insecure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("secure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { @@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) { func TestWriteZero(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) if err != nil { t.Fatal(err) } From ac8b478693dcd7dec2213daf329c1696b14b33d6 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Oct 2024 15:56:44 +0530 Subject: [PATCH 05/32] make fewer concurrent connections in test, it breaks mac and windows --- p2p/transport/tcpreuse/listener_test.go | 85 +++++++++++++++++-------- 1 file changed, 60 insertions(+), 25 deletions(-) diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index b61ffce09d..6dff0901d8 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -20,6 +20,7 @@ import ( ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multistream" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -62,7 +63,7 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func TestListenerSingle(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") - const N = 128 + const N = 64 for _, disableReuseport := range []bool{true, false} { t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { cm := NewConnMgr(disableReuseport, nil, nil) @@ -101,11 +102,15 @@ func TestListenerSingle(t *testing.T) { go func() { defer wg.Done() cc := multistream.NewMSSelect(c, "a") + defer cc.Close() buf := make([]byte, 30) n, err := cc.Read(buf) - require.NoError(t, err) - require.Equal(t, "hello-multistream", string(buf[:n])) - c.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello-multistream", string(buf[:n])) { + return + } }() } wg.Wait() @@ -147,11 +152,17 @@ func TestListenerSingle(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + defer c.Close() msgType, buf, err := c.ReadMessage() - require.NoError(t, err) - require.Equal(t, msgType, websocket.TextMessage) - require.Equal(t, "hello", string(buf)) - c.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } }() } wg.Wait() @@ -195,11 +206,17 @@ func TestListenerSingle(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + defer c.Close() msgType, buf, err := c.ReadMessage() - require.NoError(t, err) - require.Equal(t, msgType, websocket.TextMessage) - require.Equal(t, "hello", string(buf)) - c.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } }() } wg.Wait() @@ -209,7 +226,7 @@ func TestListenerSingle(t *testing.T) { func TestListenerMultiplexed(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") - const N = 128 + const N = 20 for _, disableReuseport := range []bool{true, false} { cm := NewConnMgr(disableReuseport, nil, nil) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) @@ -315,16 +332,22 @@ func TestListenerMultiplexed(t *testing.T) { defer wg.Done() for i := 0; i < N; i++ { c, err := msl.Accept() - require.NoError(t, err) + if !assert.NoError(t, err) { + return + } wg.Add(1) go func() { defer wg.Done() cc := multistream.NewMSSelect(c, "a") + defer cc.Close() buf := make([]byte, 20) n, err := cc.Read(buf) - require.NoError(t, err) - require.Equal(t, "multistream", string(buf[:n])) - cc.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "multistream", string(buf[:n])) { + return + } }() } }() @@ -337,11 +360,17 @@ func TestListenerMultiplexed(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + defer c.Close() msgType, buf, err := c.ReadMessage() - require.NoError(t, err) - require.Equal(t, msgType, websocket.TextMessage) - require.Equal(t, "websocket", string(buf)) - c.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket", string(buf)) { + return + } }() } }() @@ -354,11 +383,17 @@ func TestListenerMultiplexed(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + defer c.Close() msgType, buf, err := c.ReadMessage() - require.NoError(t, err) - require.Equal(t, msgType, websocket.TextMessage) - require.Equal(t, "websocket-tls", string(buf)) - c.Close() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket-tls", string(buf)) { + return + } }() } }() From 6d5fa533ae12f113bdecae2bb8bb80cec95c3226 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Oct 2024 16:18:57 +0530 Subject: [PATCH 06/32] add some comments --- options.go | 1 + p2p/net/upgrader/listener.go | 2 +- p2p/transport/tcpreuse/listener.go | 13 ++++++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/options.go b/options.go index 7c94ed7892..0821c1eb9b 100644 --- a/options.go +++ b/options.go @@ -644,6 +644,7 @@ func WithFxOption(opts ...fx.Option) Option { } } +// ShareTCPListener shares the same listen address between TCP and Websocket transports. func ShareTCPListener() Option { return func(cfg *Config) error { cfg.ShareTCPListener = true diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 65da2bec6c..9bee564d45 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -84,13 +84,13 @@ func (l *listener) handleIncoming() { } catcher.Reset() + // Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation. var connScope network.ConnManagementScope if sc, ok := maconn.(interface { Scope() network.ConnManagementScope }); ok { connScope = sc.Scope() } - if connScope == nil { // gate the connection if applicable if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index e0bfe8eef2..55fd85ed56 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -182,7 +182,15 @@ func (m *multiplexedListener) run() error { return err } - // gate the connection if applicable + // Gate and resource limit the connection here. + // If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer + // which clogs up our entire connection queue. + // This duplicates the responsibility of gating and resource limiting between here and the upgrader. The + // alternative without duplication requires moving the process of upgrading the connection here, which forces + // us to establish the websocket connection here. That is more duplication, or a significant breaking change. + // + // Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport + // integration tests. if m.connGater != nil && !m.connGater.InterceptAccept(c) { log.Debugf("gater blocked incoming connection on local addr %s from %s", c.LocalMultiaddr(), c.RemoteMultiaddr()) @@ -202,6 +210,7 @@ func (m *multiplexedListener) run() error { select { case acceptQueue <- struct{}{}: + // TODO: We can drop the connection, but this is similar to the behaviour in the upgrader. case <-m.ctx.Done(): c.Close() log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) @@ -211,8 +220,6 @@ func (m *multiplexedListener) run() error { go func() { defer func() { <-acceptQueue }() defer m.wg.Done() - // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? - // Drop connection because the buffer is full t, sampleC, err := getDemultiplexedConn(c, connScope) if err != nil { connScope.Done() From 0115a62bfdbcca73e371bbc52afefa3da4c534f9 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Mon, 28 Oct 2024 22:03:57 -0700 Subject: [PATCH 07/32] Add OS specific sampledconn --- .../sampledconn/sampledconn_common.go | 53 ++++++++++++++ .../internal/sampledconn/sampledconn_other.go | 9 +++ .../internal/sampledconn/sampledconn_test.go | 70 +++++++++++++++++++ .../internal/sampledconn/sampledconn_unix.go | 50 +++++++++++++ 4 files changed, 182 insertions(+) create mode 100644 p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go create mode 100644 p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go create mode 100644 p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go create mode 100644 p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go new file mode 100644 index 0000000000..39f90183ca --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -0,0 +1,53 @@ +package sampledconn + +import ( + "io" + "net" + "time" +) + +const sampleSize = 3 + +type fallbackSampledConn struct { + tcpConnInterface + Sample [sampleSize]byte + readFromSample uint8 +} + +// tcpConnInterface is the interface for TCPConn's functions +// NOTE: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be +// misused given we've read a few bytes from the connection. +type tcpConnInterface interface { + net.Conn + + CloseRead() error + CloseWrite() error + + SetLinger(sec int) error + SetKeepAlive(keepalive bool) error + SetKeepAlivePeriod(d time.Duration) error + SetNoDelay(noDelay bool) error + MultipathTCP() (bool, error) + + io.ReaderFrom + io.WriterTo +} + +func newFallbackSampledConn(conn tcpConnInterface) (*fallbackSampledConn, error) { + s := &fallbackSampledConn{tcpConnInterface: conn} + _, err := io.ReadFull(conn, s.Sample[:]) + if err != nil { + return nil, err + } + return s, nil +} + +func (sc *fallbackSampledConn) Read(b []byte) (int, error) { + if int(sc.readFromSample) != len(sc.Sample) { + red := copy(b, sc.Sample[sc.readFromSample:]) + sc.readFromSample += uint8(red) + return red, nil + } + + return sc.tcpConnInterface.Read(b) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go new file mode 100644 index 0000000000..41da906f04 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -0,0 +1,9 @@ +//go:build !unix + +package sampledconn + +type SampledConn = *fallbackSampledConn + +func NewSampledConn(conn tcpConnInterface) (SampledConn, error) { + return newFallbackSampledConn(conn) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go new file mode 100644 index 0000000000..c197a08acb --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -0,0 +1,70 @@ +package sampledconn + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSampledConn(t *testing.T) { + testCases := []string{ + "platform", + // "fallback", + } + + // Start a TCP server + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + defer listener.Close() + + serverAddr := listener.Addr().String() + + // Server goroutine + go func() { + conn, err := listener.Accept() + assert.NoError(t, err) + defer conn.Close() + + // Write some data to the connection + _, err = conn.Write([]byte("hello")) + assert.NoError(t, err) + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + // Create a TCP client + clientConn, err := net.Dial("tcp", serverAddr) + assert.NoError(t, err) + defer clientConn.Close() + + if tc == "platform" { + // Wrap the client connection in SampledConn + sampledConn, err := NewSampledConn(clientConn.(*net.TCPConn)) + assert.NoError(t, err) + assert.Equal(t, "hel", string(sampledConn.Sample[:])) + + buf := make([]byte, 5) + _, err = sampledConn.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + } else { + // Wrap the client connection in SampledConn + sampledConn, err := newFallbackSampledConn(clientConn.(tcpConnInterface)) + assert.NoError(t, err) + assert.Equal(t, "hel", string(sampledConn.Sample[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(sampledConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + + } + }) + } +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go new file mode 100644 index 0000000000..eadd2e497c --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go @@ -0,0 +1,50 @@ +//go:build unix + +package sampledconn + +import ( + "errors" + "net" + "syscall" +) + +type SampledConn struct { + *net.TCPConn + Sample [sampleSize]byte +} + +func NewSampledConn(conn *net.TCPConn) (SampledConn, error) { + s := SampledConn{ + TCPConn: conn, + } + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < sampleSize { + var n int + n, _, readErr = syscall.Recvfrom(int(fd), s.Sample[readBytes:], syscall.MSG_PEEK) + if errors.Is(readErr, syscall.EAGAIN) { + return false + } + if readErr != nil { + return true + } + readBytes += n + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +} From cdbf16c185d3b34db92adb3aff335b46f940a6cf Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 29 Oct 2024 09:30:23 -0700 Subject: [PATCH 08/32] Return net.Conn unwrapped if possible --- .../sampledconn/sampledconn_common.go | 53 ++++++++++++++----- .../internal/sampledconn/sampledconn_other.go | 8 +-- .../internal/sampledconn/sampledconn_test.go | 10 ++-- .../internal/sampledconn/sampledconn_unix.go | 16 ++---- 4 files changed, 54 insertions(+), 33 deletions(-) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go index 39f90183ca..eb71f7b44d 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -1,17 +1,44 @@ package sampledconn import ( + "errors" "io" "net" + "syscall" "time" ) -const sampleSize = 3 +const peekSize = 3 -type fallbackSampledConn struct { +type PeekedBytes = [peekSize]byte + +var errNotSupported = errors.New("not supported on this platform") + +var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") + +func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { + if c, ok := conn.(syscall.Conn); ok { + b, err := OSPeekConn(c) + if err == nil { + return b, conn, nil + } + if err != errNotSupported { + return PeekedBytes{}, nil, err + } + // Fallback to wrapping the coonn + } + + if c, ok := conn.(tcpConnInterface); ok { + return newFallbackSampledConn(c) + } + + return PeekedBytes{}, nil, ErrNotTCPConn +} + +type fallbackPeekingConn struct { tcpConnInterface - Sample [sampleSize]byte - readFromSample uint8 + peekedBytes PeekedBytes + bytesPeeked uint8 } // tcpConnInterface is the interface for TCPConn's functions @@ -33,19 +60,19 @@ type tcpConnInterface interface { io.WriterTo } -func newFallbackSampledConn(conn tcpConnInterface) (*fallbackSampledConn, error) { - s := &fallbackSampledConn{tcpConnInterface: conn} - _, err := io.ReadFull(conn, s.Sample[:]) +func newFallbackSampledConn(conn tcpConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { + s := &fallbackPeekingConn{tcpConnInterface: conn} + _, err := io.ReadFull(conn, s.peekedBytes[:]) if err != nil { - return nil, err + return s.peekedBytes, nil, err } - return s, nil + return s.peekedBytes, s, nil } -func (sc *fallbackSampledConn) Read(b []byte) (int, error) { - if int(sc.readFromSample) != len(sc.Sample) { - red := copy(b, sc.Sample[sc.readFromSample:]) - sc.readFromSample += uint8(red) +func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { + if int(sc.bytesPeeked) != len(sc.peekedBytes) { + red := copy(b, sc.peekedBytes[sc.bytesPeeked:]) + sc.bytesPeeked += uint8(red) return red, nil } diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go index 41da906f04..5197052fab 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -2,8 +2,10 @@ package sampledconn -type SampledConn = *fallbackSampledConn +import ( + "syscall" +) -func NewSampledConn(conn tcpConnInterface) (SampledConn, error) { - return newFallbackSampledConn(conn) +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + return PeekedBytes{}, errNotSupported } diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go index c197a08acb..1910bb3597 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -45,19 +45,19 @@ func TestSampledConn(t *testing.T) { if tc == "platform" { // Wrap the client connection in SampledConn - sampledConn, err := NewSampledConn(clientConn.(*net.TCPConn)) + peeked, clientConn, err := PeekBytes(clientConn.(*net.TCPConn)) assert.NoError(t, err) - assert.Equal(t, "hel", string(sampledConn.Sample[:])) + assert.Equal(t, "hel", string(peeked[:])) buf := make([]byte, 5) - _, err = sampledConn.Read(buf) + _, err = clientConn.Read(buf) assert.NoError(t, err) assert.Equal(t, "hello", string(buf)) } else { // Wrap the client connection in SampledConn - sampledConn, err := newFallbackSampledConn(clientConn.(tcpConnInterface)) + sample, sampledConn, err := newFallbackSampledConn(clientConn.(tcpConnInterface)) assert.NoError(t, err) - assert.Equal(t, "hel", string(sampledConn.Sample[:])) + assert.Equal(t, "hel", string(sample[:])) buf := make([]byte, 5) _, err = io.ReadFull(sampledConn, buf) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go index eadd2e497c..9847e8d4be 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go @@ -4,19 +4,11 @@ package sampledconn import ( "errors" - "net" "syscall" ) -type SampledConn struct { - *net.TCPConn - Sample [sampleSize]byte -} - -func NewSampledConn(conn *net.TCPConn) (SampledConn, error) { - s := SampledConn{ - TCPConn: conn, - } +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} rawConn, err := conn.SyscallConn() if err != nil { @@ -26,9 +18,9 @@ func NewSampledConn(conn *net.TCPConn) (SampledConn, error) { readBytes := 0 var readErr error err = rawConn.Read(func(fd uintptr) bool { - for readBytes < sampleSize { + for readBytes < peekSize { var n int - n, _, readErr = syscall.Recvfrom(int(fd), s.Sample[readBytes:], syscall.MSG_PEEK) + n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK) if errors.Is(readErr, syscall.EAGAIN) { return false } From 7f4f875e3e2e97fbccb31380c105f8428dfc72f9 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Wed, 30 Oct 2024 13:16:09 -0400 Subject: [PATCH 09/32] feat(tcpreuse): add Windows sampledconn --- .../internal/sampledconn/sampledconn_other.go | 2 +- .../sampledconn/sampledconn_windows.go | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go index 5197052fab..7386112395 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -1,4 +1,4 @@ -//go:build !unix +//go:build !unix && !windows package sampledconn diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go new file mode 100644 index 0000000000..46b0617996 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package sampledconn + +import ( + "errors" + "golang.org/x/sys/windows" + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < peekSize { + var n uint32 + flags := uint32(windows.MSG_PEEK) + wsabuf := windows.WSABuf{ + Len: uint32(len(s) - readBytes), + Buf: &s[readBytes], + } + + readErr = windows.WSARecv(windows.Handle(fd), &wsabuf, 1, &n, &flags, nil, nil) + if errors.Is(readErr, windows.WSAEWOULDBLOCK) { + return false + } + if readErr != nil { + return true + } + readBytes += int(n) + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +} From 00c9403268027797a36abc1c3fd5c18b638c3d37 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 10:55:40 -0700 Subject: [PATCH 10/32] Add test for metrics+unix --- p2p/transport/tcp/metrics_unix_test.go | 33 +++++++++++++++++++++++++ p2p/transport/tcpreuse/connwithscope.go | 26 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 p2p/transport/tcp/metrics_unix_test.go create mode 100644 p2p/transport/tcpreuse/connwithscope.go diff --git a/p2p/transport/tcp/metrics_unix_test.go b/p2p/transport/tcp/metrics_unix_test.go new file mode 100644 index 0000000000..094ced0d45 --- /dev/null +++ b/p2p/transport/tcp/metrics_unix_test.go @@ -0,0 +1,33 @@ +// go:build: unix + +package tcp + +import ( + "testing" + + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" + ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" + + "github.com/stretchr/testify/require" +) + +func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { + peerA, ia := makeInsecureMuxer(t) + _, ib := makeInsecureMuxer(t) + + sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil) + sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil) + + ua, err := tptu.New(ia, muxers, nil, nil, nil) + require.NoError(t, err) + ta, err := NewTCPTransport(ua, nil, sharedTCPSocketA, WithMetrics()) + require.NoError(t, err) + ub, err := tptu.New(ib, muxers, nil, nil, nil) + require.NoError(t, err) + tb, err := NewTCPTransport(ub, nil, sharedTCPSocketB, WithMetrics()) + require.NoError(t, err) + + zero := "/ip4/127.0.0.1/tcp/0" + ttransport.SubtestTransport(t, ta, tb, zero, peerA) +} diff --git a/p2p/transport/tcpreuse/connwithscope.go b/p2p/transport/tcpreuse/connwithscope.go new file mode 100644 index 0000000000..ca66f20325 --- /dev/null +++ b/p2p/transport/tcpreuse/connwithscope.go @@ -0,0 +1,26 @@ +package tcpreuse + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" + manet "github.com/multiformats/go-multiaddr/net" +) + +type connWithScope struct { + sampledconn.ManetTCPConnInterface + scope network.ConnManagementScope +} + +func (c connWithScope) Scope() network.ConnManagementScope { + return c.scope +} + +func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { + if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { + return &connWithScope{tcpconn, scope}, nil + } + + return nil, fmt.Errorf("manet.Conn is not a TCP Conn") +} From 2c053d816832954134c2a14b9641a6e2fa3ebd52 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 11:03:07 -0700 Subject: [PATCH 11/32] tcp transport: Parameterize metrics collector in TCP --- p2p/transport/tcp/metrics.go | 29 ++++++++++++------- p2p/transport/tcp/metrics_none.go | 8 +++-- p2p/transport/tcp/tcp.go | 6 ++-- .../sampledconn/sampledconn_common.go | 25 +++++++++++----- 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index 213ee2200a..50820d870c 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -24,7 +24,7 @@ var ( const collectFrequency = 10 * time.Second -var collector *aggregatingCollector +var defaultCollector *aggregatingCollector var initMetricsOnce sync.Once @@ -34,8 +34,8 @@ func initMetrics() { bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) - collector = newAggregatingCollector() - prometheus.MustRegister(collector) + defaultCollector = newAggregatingCollector() + prometheus.MustRegister(defaultCollector) const direction = "direction" @@ -196,7 +196,7 @@ func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { c.mutex.Lock() - collector.removeConn(conn.id) + c.removeConn(conn.id) c.mutex.Unlock() closedConns.WithLabelValues(direction).Inc() } @@ -204,6 +204,8 @@ func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { type tracingConn struct { id uint64 + collector *aggregatingCollector + startTime time.Time isClient bool @@ -213,7 +215,8 @@ type tracingConn struct { closeErr error } -func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { +// newTracingConn wraps a manet.Conn with a tracingConn. A nil collector will use the default collector. +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (*tracingConn, error) { initMetricsOnce.Do(func() { initMetrics() }) conn, err := tcp.NewConn(c) if err != nil { @@ -224,8 +227,12 @@ func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { isClient: isClient, Conn: c, tcpConn: conn, + collector: collector, + } + if tc.collector == nil { + tc.collector = defaultCollector } - tc.id = collector.AddConn(tc) + tc.id = tc.collector.AddConn(tc) newConns.WithLabelValues(tc.getDirection()).Inc() return tc, nil } @@ -239,7 +246,7 @@ func (c *tracingConn) getDirection() string { func (c *tracingConn) Close() error { c.closeOnce.Do(func() { - collector.ClosedConn(c, c.getDirection()) + c.collector.ClosedConn(c, c.getDirection()) c.closeErr = c.Conn.Close() }) return c.closeErr @@ -258,10 +265,12 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { type tracingListener struct { manet.Listener + collector *aggregatingCollector } -func newTracingListener(l manet.Listener) *tracingListener { - return &tracingListener{Listener: l} +// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. +func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { + return &tracingListener{Listener: l, collector: collector} } func (l *tracingListener) Accept() (manet.Conn, error) { @@ -269,5 +278,5 @@ func (l *tracingListener) Accept() (manet.Conn, error) { if err != nil { return nil, err } - return newTracingConn(conn, false) + return newTracingConn(conn, l.collector, false) } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index 8538b30c89..cbee982070 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -6,5 +6,9 @@ package tcp import manet "github.com/multiformats/go-multiaddr/net" -func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } -func newTracingListener(l manet.Listener) manet.Listener { return l } +type aggregatingCollector struct{} + +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { + return c, nil +} +func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 1b145c2b45..e197b26660 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -142,6 +142,8 @@ type TcpTransport struct { rcmgr network.ResourceManager reuse reuseport.Transport + + metricsCollector *aggregatingCollector } var _ transport.Transport = &TcpTransport{} @@ -231,7 +233,7 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p c := conn if t.enableMetrics { var err error - c, err = newTracingConn(conn, true) + c, err = newTracingConn(conn, t.metricsCollector, true) if err != nil { return nil, err } @@ -277,7 +279,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { } if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}) + list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) } return t.upgrader.UpgradeListener(t, list), nil } diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go index eb71f7b44d..7324b45849 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -6,6 +6,8 @@ import ( "net" "syscall" "time" + + manet "github.com/multiformats/go-multiaddr/net" ) const peekSize = 3 @@ -16,7 +18,7 @@ var errNotSupported = errors.New("not supported on this platform") var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") -func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { +func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) { if c, ok := conn.(syscall.Conn); ok { b, err := OSPeekConn(c) if err == nil { @@ -28,7 +30,7 @@ func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { // Fallback to wrapping the coonn } - if c, ok := conn.(tcpConnInterface); ok { + if c, ok := conn.(ManetTCPConnInterface); ok { return newFallbackSampledConn(c) } @@ -36,16 +38,18 @@ func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { } type fallbackPeekingConn struct { - tcpConnInterface + ManetTCPConnInterface peekedBytes PeekedBytes bytesPeeked uint8 } // tcpConnInterface is the interface for TCPConn's functions -// NOTE: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be -// misused given we've read a few bytes from the connection. +// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as +// a TCP Conn easier, but it's a potential footgun as you could skipped the +// peeked bytes if using the fallback type tcpConnInterface interface { net.Conn + syscall.Conn CloseRead() error CloseWrite() error @@ -60,8 +64,13 @@ type tcpConnInterface interface { io.WriterTo } -func newFallbackSampledConn(conn tcpConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { - s := &fallbackPeekingConn{tcpConnInterface: conn} +type ManetTCPConnInterface interface { + manet.Conn + tcpConnInterface +} + +func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { + s := &fallbackPeekingConn{ManetTCPConnInterface: conn} _, err := io.ReadFull(conn, s.peekedBytes[:]) if err != nil { return s.peekedBytes, nil, err @@ -76,5 +85,5 @@ func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { return red, nil } - return sc.tcpConnInterface.Read(b) + return sc.ManetTCPConnInterface.Read(b) } From 7e34d056771adb3ce4e26e635f569187be13fdb9 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 11:09:57 -0700 Subject: [PATCH 12/32] simplify demultiplex a bit --- p2p/transport/tcpreuse/demultiplex.go | 179 ++------------------- p2p/transport/tcpreuse/demultiplex_test.go | 2 +- p2p/transport/tcpreuse/listener.go | 19 ++- 3 files changed, 31 insertions(+), 169 deletions(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 342e7de0b3..864dc76040 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -4,13 +4,9 @@ import ( "bufio" "errors" "fmt" - "io" - "math" - "net" "time" - "github.com/libp2p/go-libp2p/core/network" - ma "github.com/multiformats/go-multiaddr" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" manet "github.com/multiformats/go-multiaddr/net" ) @@ -52,13 +48,17 @@ func (t DemultiplexedConnType) IsKnown() bool { return t >= 1 || t <= 3 } -func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) { +// identifyConnType attempts to identify the connection type by peeking at the +// first few bytes. +// It Callers must not use the passed in Conn after this +// function returns. if an error is returned, the connection will be closed. +func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - s, sc, err := readSampleFromConn(c, scope) + s, c, err := sampledconn.PeekBytes(c) if err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) @@ -70,174 +70,25 @@ func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (Demult } if IsMultistreamSelect(s) { - return DemultiplexedConnType_MultistreamSelect, sc, nil + return DemultiplexedConnType_MultistreamSelect, c, nil } if IsTLS(s) { - return DemultiplexedConnType_TLS, sc, nil + return DemultiplexedConnType_TLS, c, nil } if IsHTTP(s) { - return DemultiplexedConnType_HTTP, sc, nil + return DemultiplexedConnType_HTTP, c, nil } - return DemultiplexedConnType_Unknown, sc, nil + return DemultiplexedConnType_Unknown, c, nil } -// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. -// If an error occurs it only returns the error. -func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) { - // TODO: Should we remove this? This is only implemented by bufio.Reader. - // This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped - // ReadWriteCloser from multistream which does use a buffered reader underneath. - // For our present purpose, we have a net.Conn and no net.Conn implementation offers peeking. - if peekAble, ok := c.(peekAble); ok { - b, err := peekAble.Peek(len(Sample{})) - switch { - case err == nil: - mac, err := manet.WrapNetConn(c) - if err != nil { - return Sample{}, nil, err - } - - return Sample(b), mac, nil - case errors.Is(err, bufio.ErrBufferFull): - // We can only peek < len(Sample{}) data. - // fallback to sampledConn - default: - return Sample{}, nil, err - } - } - - tcpConnLike, ok := c.(tcpConnInterface) - if !ok { - return Sample{}, nil, fmt.Errorf("expected tcp-like connection") - } - - laddr, err := manet.FromNetAddr(c.LocalAddr()) - if err != nil { - return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) - } - - raddr, err := manet.FromNetAddr(c.RemoteAddr()) - if err != nil { - return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) - } - - sc := &sampledConn{ - tcpConnInterface: tcpConnLike, - maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}, - scope: scope, - } - _, err = io.ReadFull(c, sc.s[:]) - if err != nil { - return Sample{}, nil, err - } - return sc.s, sc, nil -} - -// tcpConnInterface is the interface for TCPConn's functions -// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection. -// TODO: allow SyscallConn? Disallowing it breaks metrics tracking in TCP Transport. -type tcpConnInterface interface { - net.Conn - - CloseRead() error - CloseWrite() error - - SetLinger(sec int) error - SetKeepAlive(keepalive bool) error - SetKeepAlivePeriod(d time.Duration) error - SetNoDelay(noDelay bool) error - MultipathTCP() (bool, error) - - io.ReaderFrom - io.WriterTo -} - -type maEndpoints struct { - laddr ma.Multiaddr - raddr ma.Multiaddr -} - -// LocalMultiaddr returns the local address associated with -// this connection -func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { - return c.laddr -} - -// RemoteMultiaddr returns the remote address associated with -// this connection -func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { - return c.raddr -} - -type sampledConn struct { - tcpConnInterface - maEndpoints - scope network.ConnManagementScope - s Sample - readFromSample uint8 -} - -var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow -var _ io.ReaderFrom = (*sampledConn)(nil) -var _ io.WriterTo = (*sampledConn)(nil) - -func (sc *sampledConn) Read(b []byte) (int, error) { - if int(sc.readFromSample) != len(sc.s) { - red := copy(b, sc.s[sc.readFromSample:]) - sc.readFromSample += uint8(red) - return red, nil - } - - return sc.tcpConnInterface.Read(b) -} - -// TODO: Do we need these? - -func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) { - return io.Copy(sc.tcpConnInterface, r) -} - -func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { - if int(sc.readFromSample) != len(sc.s) { - b := sc.s[sc.readFromSample:] - written, err := w.Write(b) - if written < 0 || len(b) < written { - // buggy writer, harden against this - sc.readFromSample = uint8(len(sc.s)) - total = int64(len(sc.s)) - } else { - sc.readFromSample += uint8(written) - total += int64(written) - } - if err != nil { - return total, err - } - } - - written, err := io.Copy(w, sc.tcpConnInterface) - total += written - return total, err -} - -func (sc *sampledConn) Scope() network.ConnManagementScope { - return sc.scope -} - -func (sc *sampledConn) Close() error { - sc.scope.Done() - return sc.tcpConnInterface.Close() -} - -// Sample is the byte sequence we use to demultiplex. -type Sample [3]byte - // Matchers are implemented here instead of in the transports so we can easily fuzz them together. +type Prefix = [3]byte -func IsMultistreamSelect(s Sample) bool { +func IsMultistreamSelect(s Prefix) bool { return string(s[:]) == "\x13/m" } -func IsHTTP(s Sample) bool { +func IsHTTP(s Prefix) bool { switch string(s[:]) { case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": return true @@ -246,7 +97,7 @@ func IsHTTP(s Sample) bool { } } -func IsTLS(s Sample) bool { +func IsTLS(s Prefix) bool { switch string(s[:]) { case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04": return true diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go index 3d6e91f35a..e201f2ca75 100644 --- a/p2p/transport/tcpreuse/demultiplex_test.go +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -25,7 +25,7 @@ func FuzzClash(f *testing.F) { add('\x16', '\x03', '\x04') f.Fuzz(func(t *testing.T, a, b, c byte) { - s := Sample{a, b, c} + s := Prefix{a, b, c} var total uint ms := IsMultistreamSelect(s) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 55fd85ed56..2aa61a0fb0 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -220,7 +220,7 @@ func (m *multiplexedListener) run() error { go func() { defer func() { <-acceptQueue }() defer m.wg.Done() - t, sampleC, err := getDemultiplexedConn(c, connScope) + t, c, err := identifyConnType(c) if err != nil { connScope.Done() closeErr := c.Close() @@ -229,11 +229,22 @@ func (m *multiplexedListener) run() error { return } + // TODO: Add a test that makes sure we can get the SyscallConn in Unix platforms. + // Wrap the scope into the conn. + connWithScope, err := manetConnWithScope(c, connScope) + if err != nil { + connScope.Done() + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error wrapping connection with scope: %s", err.Error()) + return + } + m.mx.RLock() demux, ok := m.listeners[t] m.mx.RUnlock() if !ok { - closeErr := sampleC.Close() + closeErr := connWithScope.Close() if closeErr != nil { log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) } else { @@ -243,9 +254,9 @@ func (m *multiplexedListener) run() error { } select { - case demux.buffer <- sampleC: + case demux.buffer <- connWithScope: case <-m.ctx.Done(): - sampleC.Close() + connWithScope.Close() return } }() From e465b298fe68fb4dfc1dd08d3c4dd097de62ae98 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 12:20:46 -0700 Subject: [PATCH 13/32] transport-testsuite(old): Parameterize subtests --- p2p/transport/testsuite/utils_suite.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/p2p/transport/testsuite/utils_suite.go b/p2p/transport/testsuite/utils_suite.go index 5e488397a5..8b002f8900 100644 --- a/p2p/transport/testsuite/utils_suite.go +++ b/p2p/transport/testsuite/utils_suite.go @@ -11,7 +11,9 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){ +type TransportSubTestFn func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) + +var Subtests = []TransportSubTestFn{ SubtestProtocols, SubtestBasic, SubtestCancel, @@ -33,12 +35,17 @@ func getFunctionName(i interface{}) string { } func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) { + t.Helper() + SubtestTransportWithFs(t, ta, tb, addr, peerA, Subtests) +} + +func SubtestTransportWithFs(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID, tests []TransportSubTestFn) { maddr, err := ma.NewMultiaddr(addr) if err != nil { t.Fatal(err) } - for _, f := range Subtests { + for _, f := range tests { t.Run(getFunctionName(f), func(t *testing.T) { f(t, ta, tb, maddr, peerA) }) From 677222a77f62e35998ee24f7ae2950876b172476 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 12:21:09 -0700 Subject: [PATCH 14/32] tcp-transport: selectively run only 1conn tests --- p2p/transport/tcp/metrics_unix_test.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/p2p/transport/tcp/metrics_unix_test.go b/p2p/transport/tcp/metrics_unix_test.go index 094ced0d45..0a09526206 100644 --- a/p2p/transport/tcp/metrics_unix_test.go +++ b/p2p/transport/tcp/metrics_unix_test.go @@ -13,6 +13,7 @@ import ( ) func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { + peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) @@ -29,5 +30,25 @@ func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" - ttransport.SubtestTransport(t, ta, tb, zero, peerA) + + // Not running any test that needs more than 1 conn because the testsuite + // opens multiple conns via multiple listeners, which is not expected to work + // with the shared TCP socket. + subtestsToRun := []ttransport.TransportSubTestFn{ + ttransport.SubtestProtocols, + ttransport.SubtestBasic, + ttransport.SubtestCancel, + ttransport.SubtestPingPong, + + // Stolen from the stream muxer test suite. + ttransport.SubtestStress1Conn1Stream1Msg, + ttransport.SubtestStress1Conn1Stream100Msg, + ttransport.SubtestStress1Conn100Stream100Msg, + ttransport.SubtestStress1Conn1000Stream10Msg, + ttransport.SubtestStress1Conn100Stream100Msg10MB, + ttransport.SubtestStreamOpenStress, + ttransport.SubtestStreamReset, + } + + ttransport.SubtestTransportWithFs(t, ta, tb, zero, peerA, subtestsToRun) } From 8c334af068c36d2e5f4a2604b658d9e16612c268 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 12:33:21 -0700 Subject: [PATCH 15/32] sampledconn: update tests to be manet aware --- .../internal/sampledconn/sampledconn_test.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go index 1910bb3597..58ec8c2a45 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -2,10 +2,13 @@ package sampledconn import ( "io" - "net" + "syscall" "testing" "time" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/stretchr/testify/assert" ) @@ -16,11 +19,11 @@ func TestSampledConn(t *testing.T) { } // Start a TCP server - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) assert.NoError(t, err) defer listener.Close() - serverAddr := listener.Addr().String() + serverAddr := listener.Multiaddr() // Server goroutine go func() { @@ -39,13 +42,16 @@ func TestSampledConn(t *testing.T) { for _, tc := range testCases { t.Run(tc, func(t *testing.T) { // Create a TCP client - clientConn, err := net.Dial("tcp", serverAddr) + clientConn, err := manet.Dial(serverAddr) assert.NoError(t, err) defer clientConn.Close() if tc == "platform" { // Wrap the client connection in SampledConn - peeked, clientConn, err := PeekBytes(clientConn.(*net.TCPConn)) + peeked, clientConn, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) assert.NoError(t, err) assert.Equal(t, "hel", string(peeked[:])) @@ -55,7 +61,7 @@ func TestSampledConn(t *testing.T) { assert.Equal(t, "hello", string(buf)) } else { // Wrap the client connection in SampledConn - sample, sampledConn, err := newFallbackSampledConn(clientConn.(tcpConnInterface)) + sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface)) assert.NoError(t, err) assert.Equal(t, "hel", string(sample[:])) From b686b8007eb051fcb6cef7eb429123f5239bce22 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 13:11:26 -0700 Subject: [PATCH 16/32] Remove unused interface --- p2p/transport/tcpreuse/demultiplex.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 864dc76040..ab79286bf9 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -1,7 +1,6 @@ package tcpreuse import ( - "bufio" "errors" "fmt" "time" @@ -10,16 +9,6 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -type peekAble interface { - // Peek returns the next n bytes without advancing the reader. The bytes stop - // being valid at the next read call. If Peek returns fewer than n bytes, it - // also returns an error explaining why the read is short. The error is - // [ErrBufferFull] if n is larger than b's buffer size. - Peek(n int) ([]byte, error) -} - -var _ peekAble = (*bufio.Reader)(nil) - // TODO: We can unexport this type and rely completely on the multiaddr passed in to // DemultiplexedListen. type DemultiplexedConnType int From a28a244d485e169126c71ab39af3a3b014621f02 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 13:25:17 -0700 Subject: [PATCH 17/32] sampledconn: Add test case back in --- .../internal/sampledconn/sampledconn_test.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go index 58ec8c2a45..a7c5a65f33 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -15,7 +15,7 @@ import ( func TestSampledConn(t *testing.T) { testCases := []string{ "platform", - // "fallback", + "fallback", } // Start a TCP server @@ -27,13 +27,15 @@ func TestSampledConn(t *testing.T) { // Server goroutine go func() { - conn, err := listener.Accept() - assert.NoError(t, err) - defer conn.Close() + for i := 0; i < len(testCases); i++ { + conn, err := listener.Accept() + assert.NoError(t, err) + defer conn.Close() - // Write some data to the connection - _, err = conn.Write([]byte("hello")) - assert.NoError(t, err) + // Write some data to the connection + _, err = conn.Write([]byte("hello")) + assert.NoError(t, err) + } }() // Give the server a moment to start From a39490f031bf8cf48ca49a70347f23976b79262f Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 13:27:28 -0700 Subject: [PATCH 18/32] Comment nits --- p2p/transport/tcpreuse/listener.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 2aa61a0fb0..7d4b4959dd 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -210,7 +210,7 @@ func (m *multiplexedListener) run() error { select { case acceptQueue <- struct{}{}: - // TODO: We can drop the connection, but this is similar to the behaviour in the upgrader. + // NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader. case <-m.ctx.Done(): c.Close() log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) @@ -229,8 +229,6 @@ func (m *multiplexedListener) run() error { return } - // TODO: Add a test that makes sure we can get the SyscallConn in Unix platforms. - // Wrap the scope into the conn. connWithScope, err := manetConnWithScope(c, connScope) if err != nil { connScope.Done() From 76f560b9be5dc945631f64757ab2422f1f06e125 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:00:23 -0700 Subject: [PATCH 19/32] tcpreuse: flip reuseport bool --- p2p/transport/tcpreuse/listener.go | 22 +++++++++++----------- p2p/transport/tcpreuse/listener_test.go | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 7d4b4959dd..c23586d2e0 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -22,25 +22,25 @@ var log = logging.Logger("tcp-demultiplex") // ConnMgr enables you to share the same listen address between TCP and WebSocket transports. type ConnMgr struct { - disableReuseport bool - reuse reuseport.Transport - connGater connmgr.ConnectionGater - rcmgr network.ResourceManager + enableReuseport bool + reuse reuseport.Transport + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager mx sync.Mutex listeners map[string]*multiplexedListener } -func NewConnMgr(disableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { +func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } return &ConnMgr{ - disableReuseport: disableReuseport, - reuse: reuseport.Transport{}, - connGater: gater, - rcmgr: rcmgr, - listeners: make(map[string]*multiplexedListener), + enableReuseport: enableReuseport, + reuse: reuseport.Transport{}, + connGater: gater, + rcmgr: rcmgr, + listeners: make(map[string]*multiplexedListener), } } @@ -53,7 +53,7 @@ func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { } func (t *ConnMgr) useReuseport() bool { - return !t.disableReuseport && ReuseportIsAvailable() + return t.enableReuseport && ReuseportIsAvailable() } func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index 6dff0901d8..ba192014c2 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -64,9 +64,9 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func TestListenerSingle(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 64 - for _, disableReuseport := range []bool{true, false} { - t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport, nil, nil) + for _, enableReuseport := range []bool{true, false} { + t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) go func() { @@ -116,8 +116,8 @@ func TestListenerSingle(t *testing.T) { wg.Wait() }) - t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport, nil, nil) + t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) require.NoError(t, err) wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} @@ -168,8 +168,8 @@ func TestListenerSingle(t *testing.T) { wg.Wait() }) - t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport, nil, nil) + t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) require.NoError(t, err) defer l.Close() @@ -227,8 +227,8 @@ func TestListenerSingle(t *testing.T) { func TestListenerMultiplexed(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 20 - for _, disableReuseport := range []bool{true, false} { - cm := NewConnMgr(disableReuseport, nil, nil) + for _, enableReuseport := range []bool{true, false} { + cm := NewConnMgr(enableReuseport, nil, nil) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) defer msl.Close() @@ -405,7 +405,7 @@ func TestListenerClose(t *testing.T) { testClose := func(listenAddr ma.Multiaddr) { // listen on port 0 - cm := NewConnMgr(true, nil, nil) + cm := NewConnMgr(false, nil, nil) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) From c314a2ab80fa1ee7d297a1a6354923f87a351272 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:02:07 -0700 Subject: [PATCH 20/32] tcpreuse: return an error on multiple listeners for the same addr+conntype --- p2p/transport/tcpreuse/listener.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index c23586d2e0..c27b0fc37b 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -146,6 +146,8 @@ type multiplexedListener struct { wg sync.WaitGroup } +var ErrListenerExists = errors.New("listener already exists for this conn type on this address") + func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { if !connType.IsKnown() { return nil, fmt.Errorf("unknown connection type: %s", connType) @@ -153,13 +155,12 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType m.mx.Lock() defer m.mx.Unlock() - l, ok := m.listeners[connType] - if ok { - return l, nil + if _, ok := m.listeners[connType]; ok { + return nil, ErrListenerExists } ctx, cancel := context.WithCancel(m.ctx) - l = &demultiplexedListener{ + l := &demultiplexedListener{ buffer: make(chan manet.Conn), inner: m.Listener, ctx: ctx, From 56dabd4f208ce9a560129e383054250557ec4513 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:10:29 -0700 Subject: [PATCH 21/32] websocket: return consistent error --- p2p/transport/websocket/conn.go | 39 +++++++++++++++------------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index df97189d90..ce51611703 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,7 @@ package websocket import ( + "errors" "io" "net" "sync" @@ -24,7 +25,7 @@ type Conn struct { secure bool DefaultMessageType int reader io.Reader - closeOnce sync.Once + closeOnceVal func() error laddr ma.Multiaddr raddr ma.Multiaddr @@ -52,13 +53,15 @@ func NewConn(raw *ws.Conn, secure bool) *Conn { return nil } - return &Conn{ + c := &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, laddr: laddr, raddr: raddr, } + c.closeOnceVal = sync.OnceValue(c.closeOnceFn) + return c } // LocalMultiaddr implements manet.Conn. @@ -142,27 +145,21 @@ func (c *Conn) Scope() network.ConnManagementScope { return nil } -// Close closes the connection. Only the first call to Close will receive the -// close error, subsequent and concurrent calls will return nil. +// Close closes the connection. +// subsequent and concurrent calls will return the same error value. // This method is thread-safe. -// TODO: Fix this ^ func (c *Conn) Close() error { - var err error - c.closeOnce.Do(func() { - err1 := c.Conn.WriteControl( - ws.CloseMessage, - ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), - time.Now().Add(GracefulCloseTimeout), - ) - err2 := c.Conn.Close() - switch { - case err1 != nil: - err = err1 - case err2 != nil: - err = err2 - } - }) - return err + return c.closeOnceVal() +} + +func (c *Conn) closeOnceFn() error { + err1 := c.Conn.WriteControl( + ws.CloseMessage, + ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), + time.Now().Add(GracefulCloseTimeout), + ) + err2 := c.Conn.Close() + return errors.Join(err1, err2) } func (c *Conn) LocalAddr() net.Addr { From 9ebfc89b7f40208e798a9f20a2ea514e7f505e46 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:11:04 -0700 Subject: [PATCH 22/32] Remove todo --- p2p/transport/tcpreuse/demultiplex.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index ab79286bf9..bd531c2918 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -9,8 +9,6 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -// TODO: We can unexport this type and rely completely on the multiaddr passed in to -// DemultiplexedListen. type DemultiplexedConnType int const ( From a722a7ca3862eeee51f06a3a8b112a7ea40f3f25 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:14:44 -0700 Subject: [PATCH 23/32] tcp: revert parameterize metrics collector --- p2p/transport/tcp/metrics.go | 29 ++++++++++------------------- p2p/transport/tcp/metrics_none.go | 8 ++------ p2p/transport/tcp/tcp.go | 6 ++---- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index 50820d870c..213ee2200a 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -24,7 +24,7 @@ var ( const collectFrequency = 10 * time.Second -var defaultCollector *aggregatingCollector +var collector *aggregatingCollector var initMetricsOnce sync.Once @@ -34,8 +34,8 @@ func initMetrics() { bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) - defaultCollector = newAggregatingCollector() - prometheus.MustRegister(defaultCollector) + collector = newAggregatingCollector() + prometheus.MustRegister(collector) const direction = "direction" @@ -196,7 +196,7 @@ func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { c.mutex.Lock() - c.removeConn(conn.id) + collector.removeConn(conn.id) c.mutex.Unlock() closedConns.WithLabelValues(direction).Inc() } @@ -204,8 +204,6 @@ func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { type tracingConn struct { id uint64 - collector *aggregatingCollector - startTime time.Time isClient bool @@ -215,8 +213,7 @@ type tracingConn struct { closeErr error } -// newTracingConn wraps a manet.Conn with a tracingConn. A nil collector will use the default collector. -func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (*tracingConn, error) { +func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { initMetricsOnce.Do(func() { initMetrics() }) conn, err := tcp.NewConn(c) if err != nil { @@ -227,12 +224,8 @@ func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool isClient: isClient, Conn: c, tcpConn: conn, - collector: collector, - } - if tc.collector == nil { - tc.collector = defaultCollector } - tc.id = tc.collector.AddConn(tc) + tc.id = collector.AddConn(tc) newConns.WithLabelValues(tc.getDirection()).Inc() return tc, nil } @@ -246,7 +239,7 @@ func (c *tracingConn) getDirection() string { func (c *tracingConn) Close() error { c.closeOnce.Do(func() { - c.collector.ClosedConn(c, c.getDirection()) + collector.ClosedConn(c, c.getDirection()) c.closeErr = c.Conn.Close() }) return c.closeErr @@ -265,12 +258,10 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { type tracingListener struct { manet.Listener - collector *aggregatingCollector } -// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. -func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { - return &tracingListener{Listener: l, collector: collector} +func newTracingListener(l manet.Listener) *tracingListener { + return &tracingListener{Listener: l} } func (l *tracingListener) Accept() (manet.Conn, error) { @@ -278,5 +269,5 @@ func (l *tracingListener) Accept() (manet.Conn, error) { if err != nil { return nil, err } - return newTracingConn(conn, l.collector, false) + return newTracingConn(conn, false) } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index cbee982070..8538b30c89 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -6,9 +6,5 @@ package tcp import manet "github.com/multiformats/go-multiaddr/net" -type aggregatingCollector struct{} - -func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { - return c, nil -} -func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } +func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } +func newTracingListener(l manet.Listener) manet.Listener { return l } diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index e197b26660..1b145c2b45 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -142,8 +142,6 @@ type TcpTransport struct { rcmgr network.ResourceManager reuse reuseport.Transport - - metricsCollector *aggregatingCollector } var _ transport.Transport = &TcpTransport{} @@ -233,7 +231,7 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p c := conn if t.enableMetrics { var err error - c, err = newTracingConn(conn, t.metricsCollector, true) + c, err = newTracingConn(conn, true) if err != nil { return nil, err } @@ -279,7 +277,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { } if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) + list = newTracingListener(&tcpListener{list, 0}) } return t.upgrader.UpgradeListener(t, list), nil } From 85be1f5f5fb0b6ea146929942ea2897c1ed48c6c Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:18:44 -0700 Subject: [PATCH 24/32] typo --- p2p/net/upgrader/listener.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 9bee564d45..a319ece31c 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -107,7 +107,7 @@ func (l *listener) handleIncoming() { if err != nil { log.Debugw("resource manager blocked accept of new connection", "error", err) if err := maconn.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + log.Warnf("failed to open incoming connection. rejected by resource manager: %s", err) } continue } From 18af9144b91b3618aedcdb1358b279387810338d Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:20:11 -0700 Subject: [PATCH 25/32] PR review --- libp2p_test.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/libp2p_test.go b/libp2p_test.go index 0aa261d23c..3de82946d8 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -761,6 +761,17 @@ func TestSharedTCPAddr(t *testing.T) { ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"), ) require.NoError(t, err) - fmt.Println(h.Addrs()) + sawTCP := false + sawWS := false + for _, addr := range h.Addrs() { + if strings.HasSuffix(addr.String(), "/tcp/8888") { + sawTCP = true + } + if strings.HasSuffix(addr.String(), "/tcp/8888/ws") { + sawWS = true + } + } + require.True(t, sawTCP) + require.True(t, sawWS) h.Close() } From f9f39b78f0da70d6db5faf28c969933435aa6ee4 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:21:37 -0700 Subject: [PATCH 26/32] typo --- p2p/net/upgrader/listener.go | 2 +- p2p/transport/tcpreuse/listener.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index a319ece31c..c2e81d2e93 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -107,7 +107,7 @@ func (l *listener) handleIncoming() { if err != nil { log.Debugw("resource manager blocked accept of new connection", "error", err) if err := maconn.Close(); err != nil { - log.Warnf("failed to open incoming connection. rejected by resource manager: %s", err) + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) } continue } diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index c27b0fc37b..0cbfd0738c 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -204,7 +204,7 @@ func (m *multiplexedListener) run() error { if err != nil { log.Debugw("resource manager blocked accept of new connection", "error", err) if err := c.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) } continue } From 9ca4ae7fda5e8c13132e9b5fd341eddee373fb06 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:23:35 -0700 Subject: [PATCH 27/32] add timeout --- p2p/transport/tcpreuse/listener.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 0cbfd0738c..e23ba99879 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "time" logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/connmgr" @@ -18,6 +19,9 @@ import ( const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel. +// How long we wait for a connection to be accepted before dropping it. +const acceptTimeout = 30 * time.Second + var log = logging.Logger("tcp-demultiplex") // ConnMgr enables you to share the same listen address between TCP and WebSocket transports. @@ -221,6 +225,8 @@ func (m *multiplexedListener) run() error { go func() { defer func() { <-acceptQueue }() defer m.wg.Done() + ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout) + defer cancelCtx() t, c, err := identifyConnType(c) if err != nil { connScope.Done() @@ -254,9 +260,8 @@ func (m *multiplexedListener) run() error { select { case demux.buffer <- connWithScope: - case <-m.ctx.Done(): + case <-ctx.Done(): connWithScope.Close() - return } }() } From 3d265837cfa2bdcd896fe253819a9c0da0897fbd Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:26:02 -0700 Subject: [PATCH 28/32] pr nit --- p2p/transport/tcpreuse/demultiplex.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index bd531c2918..62056ccfcb 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -9,6 +9,9 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) +// This is readiung the first 3 bytes of the packet. It should be instant. +const identifyConnTimeout = 1 * time.Second + type DemultiplexedConnType int const ( @@ -40,7 +43,7 @@ func (t DemultiplexedConnType) IsKnown() bool { // It Callers must not use the passed in Conn after this // function returns. if an error is returned, the connection will be closed. func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { - if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } From 348736996bd8af2ac7ba31c4c96c9a197c00f84d Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:26:10 -0700 Subject: [PATCH 29/32] remove unused option --- p2p/transport/tcp/tcp.go | 7 ------- p2p/transport/websocket/websocket.go | 7 ------- 2 files changed, 14 deletions(-) diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 1b145c2b45..c80723436e 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -117,13 +117,6 @@ func WithMetrics() Option { } } -func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { - return func(tr *TcpTransport) error { - tr.sharedTcp = mgr - return nil - } -} - // TcpTransport is the TCP transport. type TcpTransport struct { // Connection upgrader for upgrading insecure stream connections to diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 8388a7c1e3..e24cb88c6d 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -81,13 +81,6 @@ func WithTLSConfig(conf *tls.Config) Option { } } -func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { - return func(t *WebsocketTransport) error { - t.sharedTcp = mgr - return nil - } -} - // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader From 74e9c056919294a00f5f9c12e53c32324f58c185 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 20:51:10 -0700 Subject: [PATCH 30/32] Expand comment --- options.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/options.go b/options.go index 0821c1eb9b..0329b7e60b 100644 --- a/options.go +++ b/options.go @@ -644,7 +644,11 @@ func WithFxOption(opts ...fx.Option) Option { } } -// ShareTCPListener shares the same listen address between TCP and Websocket transports. +// ShareTCPListener shares the same listen address between TCP and Websocket +// transports. This lets both transports use the same TCP port. +// +// Currently this behavior is Opt-in. In a future release this will be the +// default, and this option will be removed. func ShareTCPListener() Option { return func(cfg *Config) error { cfg.ShareTCPListener = true From fd536432cea6f5cfb5f369a38c06b4b34fce304d Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 21:06:46 -0700 Subject: [PATCH 31/32] remove 0x160304 magic byte match --- p2p/transport/tcpreuse/demultiplex.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 62056ccfcb..fe58243d67 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -89,7 +89,7 @@ func IsHTTP(s Prefix) bool { func IsTLS(s Prefix) bool { switch string(s[:]) { - case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04": + case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03": return true default: return false From ccd1609deb5546dacdee305e5b002ad253486580 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 31 Oct 2024 12:30:08 -0700 Subject: [PATCH 32/32] Fix test to handle existing listener error --- p2p/transport/tcpreuse/listener.go | 5 +++-- p2p/transport/tcpreuse/listener_test.go | 26 +++++-------------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index e23ba99879..326e1e15b7 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -111,6 +111,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed t.mx.Lock() defer t.mx.Unlock() delete(t.listeners, laddr.String()) + delete(t.listeners, l.Multiaddr().String()) return l.Close() } ml = &multiplexedListener{ @@ -121,6 +122,8 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed connGater: t.connGater, rcmgr: t.rcmgr, } + t.listeners[laddr.String()] = ml + t.listeners[l.Multiaddr().String()] = ml dl, err := ml.DemultiplexedListen(connType) if err != nil { @@ -128,8 +131,6 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed return nil, errors.Join(err, cerr) } - t.listeners[laddr.String()] = ml - ml.wg.Add(1) go ml.run() diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index ba192014c2..bdb030a676 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -402,7 +402,6 @@ func TestListenerMultiplexed(t *testing.T) { } func TestListenerClose(t *testing.T) { - testClose := func(listenAddr ma.Multiaddr) { // listen on port 0 cm := NewConnMgr(false, nil, nil) @@ -417,24 +416,20 @@ func TestListenerClose(t *testing.T) { require.NoError(t, err) require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + ml.Close() + mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) - require.Equal(t, mll, ml) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + mll.Close() wl.Close() - ml.Close() ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) - require.NotEqual(t, ml.Multiaddr(), mll.Multiaddr()) - require.NotEqual(t, mll, ml) - ml.Close() - // Now listen on the specific port previously used listenAddr = ml.Multiaddr() - ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) - require.NoError(t, err) wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) require.NoError(t, err) require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) @@ -444,19 +439,8 @@ func TestListenerClose(t *testing.T) { require.NoError(t, err) require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) - mll, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) - require.NoError(t, err) - require.Equal(t, mll, ml) - - wl.Close() - ml.Close() - - ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) - require.NoError(t, err) - - require.Equal(t, ml.Multiaddr(), mll.Multiaddr()) - require.NotEqual(t, mll, ml) ml.Close() + wl.Close() } listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")} for _, listenAddr := range listenAddrs {