Skip to content

Commit

Permalink
Fixed data race on connection close().
Browse files Browse the repository at this point in the history
  • Loading branch information
tinybit committed Aug 29, 2024
1 parent 19d1a3b commit f45be27
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn net.Conn
debugf = func(format string, v ...any) {}
)

switch {
case opt.DialContext != nil:
conn, err = opt.DialContext(ctx, addr)
Expand All @@ -54,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout)
}
}

if err != nil {
return nil, err
}

if opt.Debug {
if opt.Debugf != nil {
debugf = func(format string, v ...any) {
Expand All @@ -69,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf
}
}

compression := CompressionNone
if opt.Compression != nil {
switch opt.Compression.Method {
Expand Down Expand Up @@ -97,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
maxCompressionBuffer: opt.MaxCompressionBuffer,
}
)

if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
return nil, err
}

if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
if err := connect.sendAddendum(); err != nil {
return nil, err
Expand All @@ -110,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) {
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions())
}

return connect, nil
}

Expand Down Expand Up @@ -155,9 +162,11 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
for k, v := range c.opt.Settings {
settings = append(settings, settingToProtoSetting(k, v))
}

for k, v := range querySettings {
settings = append(settings, settingToProtoSetting(k, v))
}

return settings
}

Expand All @@ -174,20 +183,25 @@ func (c *connect) isBad() bool {
if err := c.connCheck(); err != nil {
return true
}

return false
}

func (c *connect) close() error {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return nil
}

c.closed = true
c.mutex.Unlock()

c.buffer = nil
c.reader = nil
if err := c.conn.Close(); err != nil {
return err
}

return nil
}

Expand All @@ -196,6 +210,7 @@ func (c *connect) progress() (*Progress, error) {
if err := progress.Decode(c.reader, c.revision); err != nil {
return nil, err
}

c.debugf("[progress] %s", &progress)
return &progress, nil
}
Expand All @@ -205,6 +220,7 @@ func (c *connect) exception() error {
if err := e.Decode(c.reader); err != nil {
return err
}

c.debugf("[exception] %s", e.Error())
return &e
}
Expand Down Expand Up @@ -242,6 +258,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
if err := block.EncodeHeader(c.buffer, c.revision); err != nil {
return err
}

for i := range block.Columns {
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil {
return err
Expand All @@ -257,9 +274,11 @@ func (c *connect) sendData(block *proto.Block, name string) error {
compressionOffset = 0
}
}

if err := c.compressBuffer(compressionOffset); err != nil {
return err
}

if err := c.flush(); err != nil {
switch {
case errors.Is(err, syscall.EPIPE):
Expand All @@ -273,9 +292,11 @@ func (c *connect) sendData(block *proto.Block, name string) error {
}
return err
}

defer func() {
c.buffer.Reset()
}()

return nil
}

Expand Down Expand Up @@ -324,10 +345,12 @@ func (c *connect) flush() error {
// Nothing to flush.
return nil
}

n, err := c.conn.Write(c.buffer.Buf)
if err != nil {
return errors.Wrap(err, "write")
}

if n != len(c.buffer.Buf) {
return errors.New("wrote less than expected")
}
Expand Down

0 comments on commit f45be27

Please sign in to comment.