Skip to content

Commit

Permalink
Add custom Dialer to ClientOptions
Browse files Browse the repository at this point in the history
GODRIVER-195

Change-Id: I4060ae2af015d13b0ba206eb0a597c319a550c49
  • Loading branch information
skriptble committed Apr 6, 2018
1 parent 6b85f6b commit 915670b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 7 deletions.
19 changes: 12 additions & 7 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ const defaultLocalThreshold = 15 * time.Millisecond

// Client performs operations on a given topology.
type Client struct {
topology *topology.Topology
connString connstring.ConnString
localThreshold time.Duration
readPreference *readpref.ReadPref
readConcern *readconcern.ReadConcern
writeConcern *writeconcern.WriteConcern
topologyOptions []topology.Option
topology *topology.Topology
connString connstring.ConnString
localThreshold time.Duration
readPreference *readpref.ReadPref
readConcern *readconcern.ReadConcern
writeConcern *writeconcern.WriteConcern
}

// NewClient creates a new client to connect to a cluster specified by the uri.
Expand Down Expand Up @@ -78,7 +79,11 @@ func newClient(cs connstring.ConnString, opts *ClientOptions) (*Client, error) {
}
}

topo, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return client.connString }))
topts := append(
client.topologyOptions,
topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return client.connString }),
)
topo, err := topology.New(topts...)
if err != nil {
return nil, err
}
Expand Down
26 changes: 26 additions & 0 deletions mongo/client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ package mongo
import (
"time"

"github.com/mongodb/mongo-go-driver/core/connection"
"github.com/mongodb/mongo-go-driver/core/connstring"
"github.com/mongodb/mongo-go-driver/core/topology"
)

type option func(*Client) error
Expand Down Expand Up @@ -83,6 +85,30 @@ func (co *ClientOptions) ConnectTimeout(d time.Duration) *ClientOptions {
return &ClientOptions{next: co, opt: fn}
}

// Dialer specifies a custom dialer used to dial new connections to a server.
func (co *ClientOptions) Dialer(d Dialer) *ClientOptions {
var fn option = func(c *Client) error {
c.topologyOptions = append(
c.topologyOptions,
topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
return append(
opts,
topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
return append(
opts,
connection.WithDialer(func(connection.Dialer) connection.Dialer {
return d
}),
)
}),
)
}),
)
return nil
}
return &ClientOptions{next: co, opt: fn}
}

// HeartbeatInterval specifies the interval to wait between server monitoring checks.
func (co *ClientOptions) HeartbeatInterval(d time.Duration) *ClientOptions {
var fn option = func(c *Client) error {
Expand Down
26 changes: 26 additions & 0 deletions mongo/client_options_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package mongo

import (
"context"
"net"
"sync/atomic"
"testing"

"time"
Expand Down Expand Up @@ -97,3 +100,26 @@ func TestClientOptions_chainAll(t *testing.T) {
opts = opts.next
}
}

func TestClientOptions_CustomDialer(t *testing.T) {
td := &testDialer{d: &net.Dialer{}}
opts := ClientOpt.Dialer(td)
client, err := newClient(testutil.ConnString(t), opts)
require.NoError(t, err)
_, err = client.ListDatabases(context.Background(), nil)
require.NoError(t, err)
got := atomic.LoadInt32(&td.called)
if got < 1 {
t.Errorf("Custom dialer was not used when dialing new connections")
}
}

type testDialer struct {
called int32
d Dialer
}

func (td *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
atomic.AddInt32(&td.called, 1)
return td.d.DialContext(ctx, network, address)
}
7 changes: 7 additions & 0 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@
package mongo

import (
"context"
"errors"
"fmt"
"io"
"net"
"reflect"
"strings"

"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/bson/objectid"
)

// Dialer is used to make network connections.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// TransformDocument handles transforming a document of an allowable type into
// a *bson.Document. This method is called directly after most methods that
// have one or more parameters that are documents.
Expand Down

0 comments on commit 915670b

Please sign in to comment.