diff --git a/connection/connection.go b/connection/connection.go index bf39972..470fcb5 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -166,13 +166,21 @@ func (c *Connection) HandleEvent(fd int, events poller.Event) { if !c.outBuffer.IsEmpty() { if events&poller.EventWrite != 0 { - c.handleWrite(fd) + // if return true, it means closed + if c.handleWrite(fd) { + return + } + if c.outBuffer.IsEmpty() { c.outBuffer.Reset() } } } else if events&poller.EventRead != 0 { - c.handleRead(fd) + // if return true, it means closed + if c.handleRead(fd) { + return + } + if c.inBuffer.IsEmpty() { c.inBuffer.Reset() } @@ -194,13 +202,14 @@ func (c *Connection) handlerProtocol(tmpBuffer *[]byte, buffer *ringbuffer.RingB } } -func (c *Connection) handleRead(fd int) { +func (c *Connection) handleRead(fd int) (closed bool) { // TODO 避免这次内存拷贝 buf := c.loop.PacketBuf() n, err := unix.Read(c.fd, buf) if n == 0 || err != nil { if err != unix.EAGAIN { c.handleClose(fd) + closed = true } return } @@ -221,11 +230,12 @@ func (c *Connection) handleRead(fd int) { } if len(buf) != 0 { - c.sendInLoop(buf) + closed = c.sendInLoop(buf) } + return } -func (c *Connection) handleWrite(fd int) { +func (c *Connection) handleWrite(fd int) (closed bool) { first, end := c.outBuffer.PeekAll() n, err := unix.Write(c.fd, first) if err != nil { @@ -233,6 +243,7 @@ func (c *Connection) handleWrite(fd int) { return } c.handleClose(fd) + closed = true return } c.outBuffer.Retrieve(n) @@ -244,6 +255,7 @@ func (c *Connection) handleWrite(fd int) { return } c.handleClose(fd) + closed = true return } c.outBuffer.Retrieve(n) @@ -254,6 +266,8 @@ func (c *Connection) handleWrite(fd int) { log.Error("[EnableRead]", err) } } + + return } func (c *Connection) handleClose(fd int) { @@ -271,13 +285,14 @@ func (c *Connection) handleClose(fd int) { } } -func (c *Connection) sendInLoop(data []byte) { +func (c *Connection) sendInLoop(data []byte) (closed bool) { if !c.outBuffer.IsEmpty() { _, _ = c.outBuffer.Write(data) } else { n, err := unix.Write(c.fd, data) if err != nil && err != unix.EAGAIN { c.handleClose(c.fd) + closed = true return } @@ -291,6 +306,8 @@ func (c *Connection) sendInLoop(data []byte) { _ = c.loop.EnableReadWrite(c.fd) } } + + return } func sockAddrToString(sa unix.Sockaddr) string { diff --git a/poller/epoll.go b/poller/epoll.go index 6ae64b0..e908fd5 100644 --- a/poller/epoll.go +++ b/poller/epoll.go @@ -17,6 +17,7 @@ const writeEvent = unix.EPOLLOUT type Poller struct { fd int eventFd int + buf []byte running atomic.Bool waitDone chan struct{} } @@ -47,6 +48,7 @@ func Create() (*Poller, error) { return &Poller{ fd: fd, eventFd: eventFd, + buf: make([]byte, 8), waitDone: make(chan struct{}), }, nil } @@ -59,10 +61,8 @@ func (ep *Poller) Wake() error { return err } -var buf = make([]byte, 8) - func (ep *Poller) wakeHandlerRead() { - n, err := unix.Read(ep.eventFd, buf) + n, err := unix.Read(ep.eventFd, ep.buf) if err != nil || n != 8 { log.Error("wakeHandlerRead", err, n) }