diff --git a/conn.go b/conn.go index 0ef03905c6..6adb94618e 100644 --- a/conn.go +++ b/conn.go @@ -172,8 +172,7 @@ func (c *connect) settings(querySettings Settings) []proto.Setting { } func (c *connect) isBad() bool { - switch { - case c.closed: + if c.isClosed() { return true } @@ -188,6 +187,20 @@ func (c *connect) isBad() bool { return false } +func (c *connect) isClosed() bool { + c.mutexClose.Lock() + defer c.mutexClose.Unlock() + + return c.closed +} + +func (c *connect) setClosed() { + c.mutexClose.Lock() + defer c.mutexClose.Unlock() + + c.closed = true +} + func (c *connect) close() error { c.mutexClose.Lock() if c.closed { @@ -198,7 +211,11 @@ func (c *connect) close() error { c.mutexClose.Unlock() c.buffer = nil + + c.mutex.Lock() c.reader = nil + c.mutex.Unlock() + if err := c.conn.Close(); err != nil { return err } @@ -238,7 +255,7 @@ func (c *connect) compressBuffer(start int) error { } func (c *connect) sendData(block *proto.Block, name string) error { - if c.closed { + if c.isClosed() { err := errors.New("attempted sending on closed connection") c.debugf("[send data] err: %v", err) return err @@ -284,10 +301,10 @@ func (c *connect) sendData(block *proto.Block, name string) error { switch { case errors.Is(err, syscall.EPIPE): c.debugf("[send data] pipe is broken, closing connection") - c.closed = true + c.setClosed() case errors.Is(err, io.EOF): c.debugf("[send data] unexpected EOF, closing connection") - c.closed = true + c.setClosed() default: c.debugf("[send data] unexpected error: %v", err) } @@ -302,7 +319,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { } func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) { - if c.closed { + if c.isClosed() { err := errors.New("attempted reading on closed connection") c.debugf("[read data] err: %v", err) return nil, err