From 0df38f9190159806c237412c955d767f68a7d1d8 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Thu, 24 Oct 2024 22:03:54 +0200 Subject: [PATCH] Daily commit --- p2p/test/transport/transport_test.go | 16 ++++++++ p2p/transport/memory/conn.go | 12 +++--- p2p/transport/memory/listener.go | 24 ++++++----- p2p/transport/memory/stream.go | 39 +++++++----------- p2p/transport/memory/stream_test.go | 55 ++++++++++++++++++++++++++ p2p/transport/memory/transport.go | 59 ++++++++++++++++++++++------ 6 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 p2p/transport/memory/stream_test.go diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" @@ -156,6 +157,21 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "Memory", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, } func TestPing(t *testing.T) { diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 6dac7f87a1..b01f05ed1a 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,6 +2,7 @@ package memory import ( "context" + "io" "sync" "sync/atomic" @@ -66,8 +67,8 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { id := c.nextStreamID.Add(1) - ra := make(chan []byte) - wa := make(chan []byte) + // TODO: Figure out how to exchange the pipes between the two streams + ra, wa := io.Pipe() return newStream(id, ra, wa), nil } @@ -76,10 +77,9 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { select { case in := <-c.streamC: id := c.nextStreamID.Add(1) - s := newStream(id, in.outC, in.inC) - c.addStream(id, s) + c.addStream(id, in) - return s, nil + return in, nil } } @@ -88,7 +88,7 @@ func (c *conn) LocalPeer() peer.ID { return c.localPeer } // RemotePeer returns the peer ID of the remote peer. func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } -// RemotePublicKey returns the public key of the remote peer. +// RemotePublicKey returns the public pkey of the remote peer. func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } // LocalMultiaddr returns the local Multiaddr associated diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index a53f317815..8041e02aae 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -6,17 +6,16 @@ import ( ma "github.com/multiformats/go-multiaddr" "net" "sync" - "sync/atomic" ) type listener struct { + t *transport ctx context.Context cancel context.CancelFunc laddr ma.Multiaddr mu sync.Mutex - connID atomic.Int32 - streamCh chan *stream + connCh chan *conn connections map[int32]*conn } @@ -24,27 +23,26 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } -func newListener(laddr ma.Multiaddr, streamCh chan *stream) tpt.Listener { +func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ - ctx: ctx, - cancel: cancel, - laddr: laddr, - streamCh: streamCh, + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), } } // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case s := <-l.streamCh: + case c := <-l.connCh: l.mu.Lock() defer l.mu.Unlock() - id := l.connID.Add(1) - c := newConnection(id, s) - l.connections[id] = c - return nil, nil + l.connections[c.id] = c + return c, nil case <-l.ctx.Done(): return nil, l.ctx.Err() } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index e816daf952..4e425ee5af 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -2,6 +2,7 @@ package memory import ( "errors" + "io" "sync/atomic" "time" @@ -11,8 +12,8 @@ import ( type stream struct { id int32 - inC chan []byte - outC chan []byte + r *io.PipeReader + w *io.PipeWriter readCloseC chan struct{} writeCloseC chan struct{} @@ -20,50 +21,40 @@ type stream struct { closed atomic.Bool } -func newStream(id int32, in, out chan []byte) *stream { +func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { return &stream{ id: id, - inC: in, - outC: out, - readCloseC: make(chan struct{}), - writeCloseC: make(chan struct{}), + r: r, + w: w, + readCloseC: make(chan struct{}, 1), + writeCloseC: make(chan struct{}, 1), } } -func (s *stream) Read(b []byte) (n int, err error) { +func (s *stream) Read(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.readCloseC: - err = network.ErrReset - case r, ok := <-s.inC: - if !ok { - err = network.ErrReset - } else { - n = copy(b, r) - } + return 0, network.ErrReset + default: + return s.r.Read(b) } - - return n, err } -func (s *stream) Write(b []byte) (n int, err error) { +func (s *stream) Write(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.writeCloseC: - err = network.ErrReset - case s.outC <- b: - n = len(b) + return 0, network.ErrReset default: - err = network.ErrReset + return s.w.Write(b) } - - return n, err } func (s *stream) Reset() error { diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..844000cd9d --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,55 @@ +package memory + +import ( + "github.com/stretchr/testify/require" + "io" + "testing" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + //client, server := getDetachedDataChannels(t) + ra, wb := io.Pipe() + rb, wa := io.Pipe() + + clientStr := newStream(0, ra, wa) + serverStr := newStream(0, rb, wb) + + // send a foobar from the client + n, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error + _, err = clientStr.Write([]byte("foobar")) + require.Error(t, err) + //require.False(t, clientDone.Load()) + + // now read all the data on the server side + b, err := io.ReadAll(serverStr) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + // reading again should give another io.EOF + n, err = serverStr.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + //require.False(t, serverDone.Load()) + + // send something back + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, serverStr.CloseWrite()) + + // and read it at the client + //require.False(t, clientDone.Load()) + b, err = io.ReadAll(clientStr) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + clientStr.Close() + serverStr.Close() + //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index e7f0f27e02..02eb1d24ee 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -2,28 +2,52 @@ package memory import ( "context" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + "io" "sync" "sync/atomic" ) +const ( + listenerQueueSize = 16 +) + type transport struct { + pkey ic.PrivKey + pid peer.ID + psk pnet.PSK rcmgr network.ResourceManager mu sync.RWMutex connID atomic.Int32 - listeners map[ma.Multiaddr]*listener + listeners map[string]*listener connections map[int32]*conn } -func NewTransport() *transport { +func NewTransport(key ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(key) + if err != nil { + return nil, err + } + return &transport{ + rcmgr: rcmgr, + pid: id, + pkey: key, + psk: psk, + listeners: make(map[string]*listener), connections: make(map[int32]*conn), - } + }, nil } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { @@ -48,14 +72,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee // TODO: Check if there is an existing listener for this address t.mu.RLock() defer t.mu.RUnlock() - l := t.listeners[raddr] + l := t.listeners[raddr.String()] - in := make(chan []byte) - out := make(chan []byte) - s := newStream(0, in, out) - l.streamCh <- s + ra, wb := io.Pipe() + rb, wa := io.Pipe() + in, out := newStream(0, ra, wb), newStream(0, rb, wa) + inId, outId := t.connID.Add(1), t.connID.Add(1) - return newConnection(0, s), nil + l.connCh <- newConnection(inId, in) + + return newConnection(outId, out), nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -63,8 +89,15 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { - // TODO: Figure out correct channel type - return newListener(laddr, nil), nil + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + + t.mu.Lock() + defer t.mu.Unlock() + + t.listeners[laddr.String()] = l + + return l, nil } func (t *transport) Proxy() bool { @@ -82,6 +115,10 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them + for _, l := range t.listeners { + l.Close() + } + return nil }