From f45be273206be2e46aefba79b1a8edf4c526aa3b Mon Sep 17 00:00:00 2001 From: Rim Zaydullin Date: Thu, 29 Aug 2024 17:15:02 +0800 Subject: [PATCH] Fixed data race on connection close(). --- conn.go | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 259fea0ab3..11567b79f4 100644 --- a/conn.go +++ b/conn.go @@ -43,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er conn net.Conn debugf = func(format string, v ...any) {} ) + switch { case opt.DialContext != nil: conn, err = opt.DialContext(ctx, addr) @@ -54,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout) } } + if err != nil { return nil, err } + if opt.Debug { if opt.Debugf != nil { debugf = func(format string, v ...any) { @@ -69,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf } } + compression := CompressionNone if opt.Compression != nil { switch opt.Compression.Method { @@ -97,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er maxCompressionBuffer: opt.MaxCompressionBuffer, } ) + if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil { return nil, err } + if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM { if err := connect.sendAddendum(); err != nil { return nil, err @@ -110,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) { debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions()) } + return connect, nil } @@ -155,9 +162,11 @@ func (c *connect) settings(querySettings Settings) []proto.Setting { for k, v := range c.opt.Settings { settings = append(settings, settingToProtoSetting(k, v)) } + for k, v := range querySettings { settings = append(settings, settingToProtoSetting(k, v)) } + return settings } @@ -174,20 +183,25 @@ func (c *connect) isBad() bool { if err := c.connCheck(); err != nil { return true } + return false } func (c *connect) close() error { + c.mutex.Lock() if c.closed { + c.mutex.Unlock() return nil } - c.closed = true + c.mutex.Unlock() + c.buffer = nil c.reader = nil if err := c.conn.Close(); err != nil { return err } + return nil } @@ -196,6 +210,7 @@ func (c *connect) progress() (*Progress, error) { if err := progress.Decode(c.reader, c.revision); err != nil { return nil, err } + c.debugf("[progress] %s", &progress) return &progress, nil } @@ -205,6 +220,7 @@ func (c *connect) exception() error { if err := e.Decode(c.reader); err != nil { return err } + c.debugf("[exception] %s", e.Error()) return &e } @@ -242,6 +258,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { if err := block.EncodeHeader(c.buffer, c.revision); err != nil { return err } + for i := range block.Columns { if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil { return err @@ -257,9 +274,11 @@ func (c *connect) sendData(block *proto.Block, name string) error { compressionOffset = 0 } } + if err := c.compressBuffer(compressionOffset); err != nil { return err } + if err := c.flush(); err != nil { switch { case errors.Is(err, syscall.EPIPE): @@ -273,9 +292,11 @@ func (c *connect) sendData(block *proto.Block, name string) error { } return err } + defer func() { c.buffer.Reset() }() + return nil } @@ -324,10 +345,12 @@ func (c *connect) flush() error { // Nothing to flush. return nil } + n, err := c.conn.Write(c.buffer.Buf) if err != nil { return errors.Wrap(err, "write") } + if n != len(c.buffer.Buf) { return errors.New("wrote less than expected") }