diff --git a/conn.go b/conn.go index 11567b79f4..0ef03905c6 100644 --- a/conn.go +++ b/conn.go @@ -140,6 +140,7 @@ type connect struct { blockBufferSize uint8 maxCompressionBuffer int mutex sync.Mutex + mutexClose sync.Mutex } func (c *connect) settings(querySettings Settings) []proto.Setting { @@ -188,13 +189,13 @@ func (c *connect) isBad() bool { } func (c *connect) close() error { - c.mutex.Lock() + c.mutexClose.Lock() if c.closed { - c.mutex.Unlock() + c.mutexClose.Unlock() return nil } c.closed = true - c.mutex.Unlock() + c.mutexClose.Unlock() c.buffer = nil c.reader = nil diff --git a/tests/context_cancel_test.go b/tests/context_cancel_test.go new file mode 100644 index 0000000000..066fe46333 --- /dev/null +++ b/tests/context_cancel_test.go @@ -0,0 +1,261 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tests + +import ( + "context" + "log" + "testing" + "time" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/assert" +) + +func TestContextCancellationType1(t *testing.T) { + var ( + q1 = "CREATE DATABASE IF NOT EXISTS test_query_cancellation" + q2 = "DROP TABLE IF EXISTS test_query_cancellation.trips" + q3 = `CREATE TABLE test_query_cancellation.trips ( + trip_id UInt32, + pickup_datetime DateTime, + dropoff_datetime DateTime, + pickup_longitude Nullable(Float64), + pickup_latitude Nullable(Float64), + dropoff_longitude Nullable(Float64), + dropoff_latitude Nullable(Float64), + passenger_count UInt8, + trip_distance Float32, + fare_amount Float32, + extra Float32, + tip_amount Float32, + tolls_amount Float32, + total_amount Float32, + payment_type Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5), + pickup_ntaname LowCardinality(String), + dropoff_ntaname LowCardinality(String) + ) + ENGINE = MergeTree + PRIMARY KEY (pickup_datetime, dropoff_datetime);` + q4 = `INSERT INTO test_query_cancellation.trips + SELECT + number + 1 AS trip_id, + now() - INTERVAL intDiv(number, 100) SECOND AS pickup_datetime, + now() - INTERVAL intDiv(number, 100) SECOND + INTERVAL rand() % 3600 SECOND AS dropoff_datetime, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS pickup_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS pickup_latitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS dropoff_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS dropoff_latitude, + rand() % 6 + 1 AS passenger_count, + (rand() % 2000) / 100.0 AS trip_distance, + (rand() % 5000) / 100.0 AS fare_amount, + (rand() % 500) / 100.0 AS extra, + (rand() % 1000) / 100.0 AS tip_amount, + (rand() % 300) / 100.0 AS tolls_amount, + (rand() % 6000) / 100.0 AS total_amount, + CAST(rand() % 5 + 1 AS Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5)) AS payment_type, + 'Neighborhood ' || toString(rand() % 100 + 1) AS pickup_ntaname, + 'Neighborhood ' || toString(rand() % 100 + 1) AS dropoff_ntaname + FROM numbers(100000000);` + ) + + prepareQueries := []string{q1, q2, q3} + + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + + assert.Nil(t, err) + assert.NotNil(t, conn) + + if err = conn.Ping(context.Background()); err != nil { + return + } + + t.Log("Connected.") + + // prepare table + for _, query := range prepareQueries { + err = conn.Exec(context.Background(), query) + if err != nil { + log.Printf("Finished with error: %v\n", err) + conn.Close() + return + } + } + + // prepare context + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() + + doneCh := make(chan bool, 1) + queryTimeCh := make(chan time.Duration, 1) + + // run query in background + go func() { + log.Println("Running heavy query...") + + start := time.Now() + + defer func() { + log.Printf("Query took: %v\n", time.Since(start)) + queryTimeCh <- time.Since(start) + doneCh <- true + }() + + //err = conn.Exec(ctx, "OPTIMIZE TABLE test_query_cancellation.trips FINAL") + err = conn.Exec(ctx, q4) + if err != nil { + log.Printf("Finished with error: %v\n", err) + return + } + }() + + cancelBackoff := 3 * time.Second + + // let workers run for awhile and stop + go func() { + time.Sleep(cancelBackoff) + cancelCtx() + log.Printf("Context cancelled after %v.", cancelBackoff) + }() + + <-doneCh + conn.Close() + log.Println("Done.") + + queryTime := <-queryTimeCh + + assert.Less(t, queryTime-cancelBackoff, time.Second) +} + +func TestContextCancellationType2(t *testing.T) { + var ( + q1 = "CREATE DATABASE IF NOT EXISTS test_query_cancellation" + q2 = "DROP TABLE IF EXISTS test_query_cancellation.trips" + q3 = `CREATE TABLE test_query_cancellation.trips ( + trip_id UInt32, + pickup_datetime DateTime, + dropoff_datetime DateTime, + pickup_longitude Nullable(Float64), + pickup_latitude Nullable(Float64), + dropoff_longitude Nullable(Float64), + dropoff_latitude Nullable(Float64), + passenger_count UInt8, + trip_distance Float32, + fare_amount Float32, + extra Float32, + tip_amount Float32, + tolls_amount Float32, + total_amount Float32, + payment_type Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5), + pickup_ntaname LowCardinality(String), + dropoff_ntaname LowCardinality(String) + ) + ENGINE = MergeTree + PRIMARY KEY (pickup_datetime, dropoff_datetime);` + q4 = `INSERT INTO test_query_cancellation.trips + SELECT + number + 1 AS trip_id, + now() - INTERVAL intDiv(number, 100) SECOND AS pickup_datetime, + now() - INTERVAL intDiv(number, 100) SECOND + INTERVAL rand() % 3600 SECOND AS dropoff_datetime, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS pickup_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS pickup_latitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS dropoff_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS dropoff_latitude, + rand() % 6 + 1 AS passenger_count, + (rand() % 2000) / 100.0 AS trip_distance, + (rand() % 5000) / 100.0 AS fare_amount, + (rand() % 500) / 100.0 AS extra, + (rand() % 1000) / 100.0 AS tip_amount, + (rand() % 300) / 100.0 AS tolls_amount, + (rand() % 6000) / 100.0 AS total_amount, + CAST(rand() % 5 + 1 AS Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5)) AS payment_type, + 'Neighborhood ' || toString(rand() % 100 + 1) AS pickup_ntaname, + 'Neighborhood ' || toString(rand() % 100 + 1) AS dropoff_ntaname + FROM numbers(30000000);` + ) + + prepareQueries := []string{q1, q2, q3, q4} + + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + + assert.Nil(t, err) + assert.NotNil(t, conn) + + if err = conn.Ping(context.Background()); err != nil { + return + } + + t.Log("Connected.") + + // prepare table + for _, query := range prepareQueries { + err = conn.Exec(context.Background(), query) + if err != nil { + log.Printf("Finished with error: %v\n", err) + conn.Close() + return + } + } + + // prepare context + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() + + doneCh := make(chan bool, 1) + queryTimeCh := make(chan time.Duration, 1) + + // run query in background + go func() { + log.Println("Running heavy query...") + + start := time.Now() + + defer func() { + log.Printf("Query took: %v\n", time.Since(start)) + queryTimeCh <- time.Since(start) + doneCh <- true + }() + + err = conn.Exec(ctx, "OPTIMIZE TABLE test_query_cancellation.trips FINAL") + if err != nil { + log.Printf("Finished with error: %v\n", err) + return + } + }() + + cancelBackoff := 3 * time.Second + + // let workers run for awhile and stop + go func() { + time.Sleep(cancelBackoff) + cancelCtx() + log.Printf("Context cancelled after %v.", cancelBackoff) + }() + + <-doneCh + conn.Close() + log.Println("Done.") + + queryTime := <-queryTimeCh + + assert.Less(t, queryTime-cancelBackoff, time.Second) +}