Skip to content

Commit

Permalink
opt: release outboundBuffer immediately in AsyncWrite(v) if Conn is c…
Browse files Browse the repository at this point in the history
…losed (#673)

Fixes #672
  • Loading branch information
panjf2000 authored Jan 11, 2025
1 parent e9a1101 commit a15081a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 5 deletions.
2 changes: 2 additions & 0 deletions connection_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ func (c *conn) asyncWrite(a any) (err error) {
}()

if !c.opened {
c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer
return net.ErrClosed
}

Expand All @@ -273,6 +274,7 @@ func (c *conn) asyncWritev(a any) (err error) {
}()

if !c.opened {
c.outboundBuffer.Release() // release all remaining bytes in the outbound buffer
return net.ErrClosed
}

Expand Down
10 changes: 5 additions & 5 deletions connection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,21 +421,21 @@ var workerPool = nonBlockingPool{Pool: goPool.Default()}
// func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} }

func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error {
_, err := c.Write(buf)

callback := func() error {
fn := func() error {
_, err := c.Write(buf)
if cb != nil {
_ = cb(c, err)
}
return err
}

var err error
select {
case c.loop.ch <- callback:
case c.loop.ch <- fn:
default:
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
err = workerPool.Go(func() {
c.loop.ch <- callback
c.loop.ch <- fn
})
}

Expand Down
83 changes: 83 additions & 0 deletions gnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,89 @@ func TestMultiInstLoggerRace(t *testing.T) {
assert.ErrorIs(t, g.Wait(), errorx.ErrUnsupportedProtocol)
}

type testDisconnectedAsyncWriteServer struct {
BuiltinEventEngine
tester *testing.T
addr string
writev, clientStarted bool
exit atomic.Bool
}

func (t *testDisconnectedAsyncWriteServer) OnTraffic(c Conn) Action {
_, err := c.Next(0)
require.NoErrorf(t.tester, err, "c.Next error: %v", err)

go func() {
for range time.Tick(100 * time.Millisecond) {
if t.exit.Load() {
break
}

if t.writev {
err = c.AsyncWritev([][]byte{[]byte("hello"), []byte("hello")}, func(_ Conn, err error) error {
if err == nil {
return nil
}

require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err)
t.exit.Store(true)
return nil
})
} else {
err = c.AsyncWrite([]byte("hello"), func(_ Conn, err error) error {
if err == nil {
return nil
}

require.ErrorIsf(t.tester, err, net.ErrClosed, "expected error: %v, but got: %v", net.ErrClosed, err)
t.exit.Store(true)
return nil
})
}

if err != nil {
return
}
}
}()

return None
}

func (t *testDisconnectedAsyncWriteServer) OnTick() (delay time.Duration, action Action) {
delay = 500 * time.Millisecond

if t.exit.Load() {
action = Shutdown
return
}

if !t.clientStarted {
t.clientStarted = true
go func() {
c, err := net.Dial("tcp", t.addr)
require.NoError(t.tester, err)
_, err = c.Write([]byte("hello"))
require.NoError(t.tester, err)
require.NoError(t.tester, c.Close())
}()
}
return
}

func TestDisconnectedAsyncWrite(t *testing.T) {
t.Run("async-write", func(t *testing.T) {
events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10000"}
err := Run(events, "tcp://:10000", WithTicker(true))
assert.NoError(t, err)
})
t.Run("async-writev", func(t *testing.T) {
events := &testDisconnectedAsyncWriteServer{tester: t, addr: ":10001", writev: true}
err := Run(events, "tcp://:10001", WithTicker(true))
assert.NoError(t, err)
})
}

var errIncompletePacket = errors.New("incomplete packet")

type simServer struct {
Expand Down

0 comments on commit a15081a

Please sign in to comment.