Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to handle context cancellations for TCP protocol #1389

Merged
merged 13 commits into from
Sep 23, 2024
Prev Previous commit
Next Next commit
Fixed data race on connection close().
  • Loading branch information
tinybit committed Aug 29, 2024
commit f45be273206be2e46aefba79b1a8edf4c526aa3b
25 changes: 24 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn net.Conn
debugf = func(format string, v ...any) {}
)

switch {
case opt.DialContext != nil:
conn, err = opt.DialContext(ctx, addr)
Expand All @@ -54,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout)
}
}

if err != nil {
return nil, err
}

if opt.Debug {
if opt.Debugf != nil {
debugf = func(format string, v ...any) {
Expand All @@ -69,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf
}
}

compression := CompressionNone
if opt.Compression != nil {
switch opt.Compression.Method {
Expand Down Expand Up @@ -97,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
maxCompressionBuffer: opt.MaxCompressionBuffer,
}
)

if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
return nil, err
}

if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
if err := connect.sendAddendum(); err != nil {
return nil, err
Expand All @@ -110,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) {
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions())
}

return connect, nil
}

Expand Down Expand Up @@ -155,9 +162,11 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
for k, v := range c.opt.Settings {
settings = append(settings, settingToProtoSetting(k, v))
}

for k, v := range querySettings {
settings = append(settings, settingToProtoSetting(k, v))
}

return settings
}

Expand All @@ -174,20 +183,25 @@ func (c *connect) isBad() bool {
if err := c.connCheck(); err != nil {
return true
}

return false
}

func (c *connect) close() error {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return nil
}

c.closed = true
c.mutex.Unlock()

c.buffer = nil
c.reader = nil
if err := c.conn.Close(); err != nil {
return err
}

return nil
}

Expand All @@ -196,6 +210,7 @@ func (c *connect) progress() (*Progress, error) {
if err := progress.Decode(c.reader, c.revision); err != nil {
return nil, err
}

c.debugf("[progress] %s", &progress)
return &progress, nil
}
Expand All @@ -205,6 +220,7 @@ func (c *connect) exception() error {
if err := e.Decode(c.reader); err != nil {
return err
}

c.debugf("[exception] %s", e.Error())
return &e
}
Expand Down Expand Up @@ -242,6 +258,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
if err := block.EncodeHeader(c.buffer, c.revision); err != nil {
return err
}

for i := range block.Columns {
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil {
return err
Expand All @@ -257,9 +274,11 @@ func (c *connect) sendData(block *proto.Block, name string) error {
compressionOffset = 0
}
}

if err := c.compressBuffer(compressionOffset); err != nil {
return err
}

if err := c.flush(); err != nil {
switch {
case errors.Is(err, syscall.EPIPE):
Expand All @@ -273,9 +292,11 @@ func (c *connect) sendData(block *proto.Block, name string) error {
}
return err
}

defer func() {
c.buffer.Reset()
}()

return nil
}

Expand Down Expand Up @@ -324,10 +345,12 @@ func (c *connect) flush() error {
// Nothing to flush.
return nil
}

n, err := c.conn.Write(c.buffer.Buf)
if err != nil {
return errors.Wrap(err, "write")
}

if n != len(c.buffer.Buf) {
return errors.New("wrote less than expected")
}
Expand Down
Loading