Skip to content

Commit

Permalink
GT-292 Fix reusing same connection with different Authentication para…
Browse files Browse the repository at this point in the history
…meters (#452)

* Add test-case where same connection re-used with different Authentication params

* Add more comments of reusing same connection with different Auhtentication parameters

* Fix reusing same connection with different Authentication parameters passed via driver.NewClient
  • Loading branch information
nikita-vanyasin authored Dec 13, 2022
1 parent 875dd62 commit b89f42d
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## [master](https://github.com/arangodb/go-driver/tree/master) (N/A)
- Add support for `checksum` in Collections
- Fix reusing same connection with different Authentication parameters passed via driver.NewClient

## [1.4.0](https://github.com/arangodb/go-driver/tree/v1.4.0) (2022-10-04)
- Add `hex` property to analyzer's properties
Expand Down
17 changes: 13 additions & 4 deletions cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type ServerConnectionBuilder func(endpoint string) (driver.Connection, error)
// The given connections are existing connections to each of the servers.
func NewConnection(config ConnectionConfig, connectionBuilder ServerConnectionBuilder, endpoints []string) (driver.Connection, error) {
if connectionBuilder == nil {
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must a connection builder"})
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must provide a connection builder"})
}
if len(endpoints) == 0 {
return nil, driver.WithStack(driver.InvalidArgumentError{Message: "Must provide at least 1 endpoint"})
Expand Down Expand Up @@ -285,7 +285,7 @@ func (c *clusterConnection) UpdateEndpoints(endpoints []string) error {
return nil
}

// Configure the authentication used for this connection.
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
func (c *clusterConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
Expand All @@ -300,11 +300,20 @@ func (c *clusterConnection) SetAuthentication(auth driver.Authentication) (drive
newServerConnections[i] = authConn
}

// Save authentication
// These two lines are not required for normal work but left for backward compatibility
// of SetAuthentication method - it was returning self object
c.auth = auth
c.servers = newServerConnections

return c, nil
return &clusterConnection{
connectionBuilder: c.connectionBuilder,
servers: c.servers,
endpoints: c.endpoints,
current: c.current,
mutex: sync.RWMutex{},
defaultTimeout: c.defaultTimeout,
auth: c.auth,
}, nil
}

// Protocols returns all protocols used by this connection.
Expand Down
4 changes: 2 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
velocypack "github.com/arangodb/go-velocypack"
)

// Connection is a connenction to a database server using a specific protocol.
// Connection is a connection to a database server using a specific protocol.
type Connection interface {
// NewRequest creates a new request with given method and path.
NewRequest(method, path string) (Request, error)
Expand All @@ -47,7 +47,7 @@ type Connection interface {
// UpdateEndpoints reconfigures the connection to use the given endpoints.
UpdateEndpoints(endpoints []string) error

// Configure the authentication used for this connection.
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
SetAuthentication(Authentication) (Connection, error)

// Protocols returns all protocols used by this connection.
Expand Down
2 changes: 1 addition & 1 deletion http/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (c *authenticatedConnection) UpdateEndpoints(endpoints []string) error {
return nil
}

// Configure the authentication used for this connection.
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
func (c *authenticatedConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
result, err := c.conn.SetAuthentication(auth)
if err != nil {
Expand Down
11 changes: 6 additions & 5 deletions http/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (c *httpConnection) UpdateEndpoints(endpoints []string) error {
return nil
}

// Configure the authentication used for this connection.
// SetAuthentication creates a copy of connection wrapper for given auth parameters.
func (c *httpConnection) SetAuthentication(auth driver.Authentication) (driver.Connection, error) {
var httpAuth httpAuthentication
switch auth.Type() {
Expand Down Expand Up @@ -471,8 +471,8 @@ func (h *RepeatConnection) UpdateEndpoints(endpoints []string) error {
return h.conn.UpdateEndpoints(endpoints)
}

// Configure the authentication used for this connection.
// Returns ErrAuthenticationNotChanged in when the authentication is not changed.
// SetAuthentication configure the authentication used for this connection.
// Returns ErrAuthenticationNotChanged when the authentication is not changed.
func (h *RepeatConnection) SetAuthentication(authentication driver.Authentication) (driver.Connection, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
Expand All @@ -481,16 +481,17 @@ func (h *RepeatConnection) SetAuthentication(authentication driver.Authenticatio
return h, ErrAuthenticationNotChanged
}

_, err := h.conn.SetAuthentication(authentication)
newConn, err := h.conn.SetAuthentication(authentication)
if err != nil {
return nil, driver.WithStack(err)
}
h.conn = newConn
h.auth = authentication

return h, nil
}

// Protocols returns all protocols used by this connection.
func (h RepeatConnection) Protocols() driver.ProtocolSet {
func (h *RepeatConnection) Protocols() driver.ProtocolSet {
return h.conn.Protocols()
}
95 changes: 95 additions & 0 deletions test/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ package test
import (
"context"
"crypto/tls"
"fmt"
"log"
httplib "net/http"
_ "net/http/pprof"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -500,3 +503,95 @@ func TestCreateClientHttpRepeatConnection(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 2, requestRepeat.counter)
}

// TestClientConnectionReuse checks that reusing same connection with different auth parameters is possible using
func TestClientConnectionReuse(t *testing.T) {
if os.Getenv("TEST_CONNECTION") == "vst" {
t.Skip("not possible with VST connections by design")
return
}

c := createClientFromEnv(t, true)
ctx := context.Background()

prefix := t.Name()
dbUsers := map[string]driver.CreateDatabaseUserOptions{
prefix + "-db1": {UserName: prefix + "-user1", Password: "password1"},
prefix + "-db2": {UserName: prefix + "-user2", Password: "password2"},
}
for dbName, userOptions := range dbUsers {
ensureDatabase(ctx, c, dbName, &driver.CreateDatabaseOptions{
Users: []driver.CreateDatabaseUserOptions{userOptions},
Options: driver.CreateDatabaseDefaultOptions{},
}, t)
}

var wg sync.WaitGroup
const clientsPerDB = 20
startTime := time.Now()

const testDuration = time.Second * 10
if testing.Verbose() {
wg.Add(1)
go func() {
defer wg.Done()

for {
stats, _ := c.Statistics(ctx)
t.Logf("goroutine count: %d, server connections: %d", runtime.NumGoroutine(), stats.Client.HTTPConnections)
if time.Now().Sub(startTime) > testDuration {
break
}
time.Sleep(1 * time.Second)
}
}()
}

conn := createConnection(t, false)
for dbName, userOptions := range dbUsers {
t.Logf("Starting %d goroutines for DB %s ...", clientsPerDB, dbName)
for i := 0; i < clientsPerDB; i++ {
wg.Add(1)
go func(dbName string, userOptions driver.CreateDatabaseUserOptions, conn driver.Connection) {
defer wg.Done()
for {
if time.Now().Sub(startTime) > testDuration {
break
}

// the test will pass only if checkDBAccess is using mutex
err := checkDBAccess(ctx, conn, dbName, userOptions.UserName, userOptions.Password)
require.NoError(t, err)

time.Sleep(10 * time.Millisecond)
}
}(dbName, userOptions, conn)
}
}
wg.Wait()
}

func checkDBAccess(ctx context.Context, conn driver.Connection, dbName, username, password string) error {
client, err := driver.NewClient(driver.ClientConfig{
Connection: conn,
Authentication: driver.BasicAuthentication(username, password),
})
if err != nil {
return err
}

dbExists, err := client.DatabaseExists(ctx, dbName)
if err != nil {
return errors.Wrapf(err, "DatabaseExists failed")
}
if !dbExists {
return fmt.Errorf("db %s must exist for any user", dbName)
}

_, err = client.Database(ctx, dbName)
if err != nil {
return errors.Wrapf(err, "db %s must be accessible for user %s", dbName, username)
}

return nil
}

0 comments on commit b89f42d

Please sign in to comment.