Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge upstream #385

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func TestBatch_Errors(t *testing.T) {
t.Fatal(err)
}

b := session.NewBatch(LoggedBatch)
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
if err := session.ExecuteBatch(b); err == nil {
b := session.Batch(LoggedBatch)
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
if err := b.Exec(); err == nil {
t.Fatal("expected to get error for invalid query in batch")
}
}
Expand All @@ -44,15 +44,17 @@ func TestBatch_WithTimestamp(t *testing.T) {

micros := time.Now().UnixNano()/1e3 - 1000

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.WithTimestamp(micros)
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")

if err := b.Exec(); err != nil {
t.Fatal(err)
}

var storedTs int64
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
t.Fatal(err)
}

Expand Down
34 changes: 17 additions & 17 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -565,15 +565,15 @@ func TestCAS(t *testing.T) {
t.Fatal("truncate:", err)
}

successBatch := session.NewBatch(LoggedBatch)
successBatch := session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if !applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

successBatch = session.NewBatch(LoggedBatch)
successBatch = session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
casMap := make(map[string]interface{})
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
Expand All @@ -582,22 +582,22 @@ func TestCAS(t *testing.T) {
t.Fatal("insert should have been applied")
}

failBatch := session.NewBatch(LoggedBatch)
failBatch := session.Batch(LoggedBatch)
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

insertBatch := session.NewBatch(LoggedBatch)
insertBatch := session.Batch(LoggedBatch)
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
if err := session.ExecuteBatch(insertBatch); err != nil {
t.Fatal("insert:", err)
}

failBatch = session.NewBatch(LoggedBatch)
failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
Expand Down Expand Up @@ -722,7 +722,7 @@ func TestBatch(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 100; i++ {
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -754,9 +754,9 @@ func TestUnpreparedBatch(t *testing.T) {

var batch *Batch
if session.cfg.ProtoVersion == 2 {
batch = session.NewBatch(CounterBatch)
batch = session.Batch(CounterBatch)
} else {
batch = session.NewBatch(UnloggedBatch)
batch = session.Batch(UnloggedBatch)
}

for i := 0; i < 100; i++ {
Expand Down Expand Up @@ -795,7 +795,7 @@ func TestBatchLimit(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 65537; i++ {
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -849,7 +849,7 @@ func TestTooManyQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -881,7 +881,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -1454,7 +1454,7 @@ func TestBatchQueryInfo(t *testing.T) {
return values, nil
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)

if err := session.ExecuteBatch(batch); err != nil {
Expand Down Expand Up @@ -1582,7 +1582,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
}

stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query(stmt, "bar")
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
Expand Down Expand Up @@ -1966,7 +1966,7 @@ func TestBatchStats(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)

Expand Down Expand Up @@ -2009,7 +2009,7 @@ func TestBatchObserve(t *testing.T) {

var observedBatch *observation

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
if observedBatch != nil {
t.Fatal("batch observe called more than once")
Expand Down Expand Up @@ -2632,7 +2632,7 @@ func TestUnsetColBatch(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
Expand Down
8 changes: 7 additions & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ type ClusterConfig struct {
// Initial keyspace. Optional.
Keyspace string

// Number of connections per host.
// The size of the connection pool for each host.
// The pool filling runs in separate gourutine during the session initialization phase.
// gocql will always try to get 1 connection on each host pool
// during session initialization AND it will attempt
// to fill each pool afterward asynchronously if NumConns > 1.
// Notice: There is no guarantee that pool filling will be finished in the initialization phase.
// Also, it describes a maximum number of connections at the same time.
// Default: 2
NumConns int

Expand Down
30 changes: 11 additions & 19 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,11 @@ import (
"github.com/gocql/gocql/internal/streams"
)

var (
defaultApprovedAuthenticators = []string{
"org.apache.cassandra.auth.PasswordAuthenticator",
"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
"com.datastax.bdp.cassandra.auth.DseAuthenticator",
"io.aiven.cassandra.auth.AivenAuthenticator",
"com.ericsson.bss.cassandra.ecaudit.auth.AuditPasswordAuthenticator",
"com.amazon.helenus.auth.HelenusAuthenticator",
"com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator",
"com.scylladb.auth.SaslauthdAuthenticator",
"com.scylladb.auth.TransitionalAuthenticator",
"com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator",
}
)

// approve the authenticator with the list of allowed authenticators or default list if approvedAuthenticators is empty.
// approve the authenticator with the list of allowed authenticators. If the provided list is empty,
// the given authenticator is allowed.
func approve(authenticator string, approvedAuthenticators []string) bool {
if len(approvedAuthenticators) == 0 {
approvedAuthenticators = defaultApprovedAuthenticators
return true
}
for _, s := range approvedAuthenticators {
if authenticator == s {
Expand Down Expand Up @@ -72,9 +58,15 @@ type WarningHandler interface {
HandleWarnings(qry ExecutableQuery, host *HostInfo, warnings []string)
}

// PasswordAuthenticator specifies credentials to be used when authenticating.
// It can be configured with an "allow list" of authenticator class names to avoid
// attempting to authenticate with Cassandra if it doesn't provide an expected authenticator.
type PasswordAuthenticator struct {
Username string
Password string
Username string
Password string
// Setting this to nil or empty will allow authenticating with any authenticator
// provided by the server. This is the default behavior of most other driver
// implementations.
AllowedAuthenticators []string
}

Expand Down
27 changes: 15 additions & 12 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,21 @@ const (

func TestApprove(t *testing.T) {
tests := map[bool]bool{
approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true,
approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true,
approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true,
approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true,
approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true,
approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true,
approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", nil): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true,
approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true,
approve("org.apache.cassandra.auth.MutualTlsWithPasswordFallbackAuthenticator", []string{}): true,
approve("org.apache.cassandra.auth.MutualTlsAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true,
approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true,
approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true,
approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true,
approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true,
approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true,
approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", nil): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.NotFakeAuthenticator"}): false,
}
for k, v := range tests {
if k != v {
Expand Down
12 changes: 11 additions & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@
// }
// defer session.Close()
//
// By default, PasswordAuthenticator will attempt to authenticate regardless of what implementation the server returns
// in its AUTHENTICATE message as its authenticator, (e.g. org.apache.cassandra.auth.PasswordAuthenticator). If you
// wish to restrict this you may use PasswordAuthenticator.AllowedAuthenticators:
//
// cluster.Authenticator = gocql.PasswordAuthenticator {
// Username: "user",
// Password: "password"
// AllowedAuthenticators: []string{"org.apache.cassandra.auth.PasswordAuthenticator"},
// }
//
// # Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
Expand Down Expand Up @@ -280,7 +290,7 @@
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
Expand Down
15 changes: 13 additions & 2 deletions example_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// Example_batch demonstrates how to execute a batch of statements.
Expand All @@ -24,7 +25,7 @@ func Example_batch() {

ctx := context.Background()

b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
Args: []interface{}{1, 2, "1.2"},
Expand All @@ -35,11 +36,19 @@ func Example_batch() {
Args: []interface{}{1, 3, "1.3"},
Idempotent: true,
})

err = session.ExecuteBatch(b)
if err != nil {
log.Fatal(err)
}

err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
Exec()
if err != nil {
log.Fatal(err)
}

scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
for scanner.Next() {
var pk, ck int32
Expand All @@ -52,4 +61,6 @@ func Example_batch() {
}
// 1 2 1.2
// 1 3 1.3
// 1 4 1.4
// 1 5 1.5
}
5 changes: 3 additions & 2 deletions example_lwt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
Expand Down Expand Up @@ -37,7 +38,7 @@ func ExampleSession_MapExecuteBatchCAS() {
}

executeBatch := func(ck2Version int) {
b := session.NewBatch(gocql.LoggedBatch)
b := session.Batch(gocql.LoggedBatch)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
Args: []interface{}{"b", "pk1", "ck1", 1},
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestCustomPayloadMessages(t *testing.T) {
iter.Close()

// Batch Message
b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.CustomPayload = customPayload
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
if err := session.ExecuteBatch(b); err != nil {
Expand Down
Loading
Loading