diff --git a/connection_unix.go b/connection_unix.go index ac669d4b5..ff22a7e22 100644 --- a/connection_unix.go +++ b/connection_unix.go @@ -252,6 +252,7 @@ func (c *conn) asyncWrite(a any) (err error) { }() if !c.opened { + c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer return net.ErrClosed } @@ -273,6 +274,7 @@ func (c *conn) asyncWritev(a any) (err error) { }() if !c.opened { + c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer return net.ErrClosed } diff --git a/connection_windows.go b/connection_windows.go index fe8ccb187..e40136e2e 100644 --- a/connection_windows.go +++ b/connection_windows.go @@ -421,21 +421,21 @@ var workerPool = nonBlockingPool{Pool: goPool.Default()} // func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} } func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error { - _, err := c.Write(buf) - - callback := func() error { + fn := func() error { + _, err := c.Write(buf) if cb != nil { _ = cb(c, err) } return err } + var err error select { - case c.loop.ch <- callback: + case c.loop.ch <- fn: default: // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. err = workerPool.Go(func() { - c.loop.ch <- callback + c.loop.ch <- fn }) } diff --git a/gnet_test.go b/gnet_test.go index b6755e038..2b02a08b6 100644 --- a/gnet_test.go +++ b/gnet_test.go @@ -1532,6 +1532,89 @@ func TestMultiInstLoggerRace(t *testing.T) { assert.ErrorIs(t, g.Wait(), errorx.ErrUnsupportedProtocol) } +type testDisconnectedAsyncWriteServer struct { + BuiltinEventEngine + tester *testing.T + addr string + writev, clientStarted bool + exit atomic.Bool +} + +func (t *testDisconnectedAsyncWriteServer) OnTraffic(c Conn) Action { + _, err := c.Next(0) + require.NoErrorf(t.tester, err, "c.Next error: %v", err) + + go func() { + for range time.Tick(100 * time.Millisecond) { + if t.exit.Load() { + break + } + + if t.writev { + err = c.AsyncWritev([][]byte{[]byte("hello"), []byte("hello")}, func(_ Conn, err error) error { + if err == nil { + return nil + } + + require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err) + t.exit.Store(true) + return nil + }) + } else { + err = c.AsyncWrite([]byte("hello"), func(_ Conn, err error) error { + if err == nil { + return nil + } + + require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err) + t.exit.Store(true) + return nil + }) + } + + if err != nil { + return + } + } + }() + + return None +} + +func (t *testDisconnectedAsyncWriteServer) OnTick() (delay time.Duration, action Action) { + delay = 500 * time.Millisecond + + if t.exit.Load() { + action = Shutdown + return + } + + if !t.clientStarted { + t.clientStarted = true + go func() { + c, err := net.Dial("tcp", t.addr) + require.NoError(t.tester, err) + _, err = c.Write([]byte("hello")) + require.NoError(t.tester, err) + require.NoError(t.tester, c.Close()) + }() + } + return +} + +func TestDisconnectedAsyncWrite(t *testing.T) { + t.Run("async-write", func(t *testing.T) { + events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10000"} + err := Run(events, "tcp://:10000", WithTicker(true)) + assert.NoError(t, err) + }) + t.Run("async-writev", func(t *testing.T) { + events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10001", writev: true} + err := Run(events, "tcp://:10001", WithTicker(true)) + assert.NoError(t, err) + }) +} + var errIncompletePacket = errors.New("incomplete packet") type simServer struct {