Skip to content

Commit d670bf2

Browse files
authored
Merge pull request #152 from nats-io/connection-pool
add connection pool so we don't leak connections
2 parents b6bb02b + 0768570 commit d670bf2

11 files changed

+493
-110
lines changed

controllers/jetstream/conn_pool.go

+273
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package jetstream
2+
3+
import (
4+
"crypto/sha256"
5+
"crypto/tls"
6+
"encoding/json"
7+
"fmt"
8+
"os"
9+
"sync"
10+
11+
"github.com/nats-io/nats.go"
12+
"github.com/sirupsen/logrus"
13+
"golang.org/x/sync/singleflight"
14+
)
15+
16+
type natsContext struct {
17+
Name string `json:"name"`
18+
URL string `json:"url"`
19+
JWT string `json:"jwt"`
20+
Seed string `json:"seed"`
21+
Credentials string `json:"credential"`
22+
Nkey string `json:"nkey"`
23+
Token string `json:"token"`
24+
Username string `json:"username"`
25+
Password string `json:"password"`
26+
TLSCAs []string `json:"tls_ca"`
27+
TLSCert string `json:"tls_cert"`
28+
TLSKey string `json:"tls_key"`
29+
}
30+
31+
func (c *natsContext) copy() *natsContext {
32+
if c == nil {
33+
return nil
34+
}
35+
cp := *c
36+
return &cp
37+
}
38+
39+
func (c *natsContext) hash() (string, error) {
40+
b, err := json.Marshal(c)
41+
if err != nil {
42+
return "", fmt.Errorf("error marshaling context to json: %v", err)
43+
}
44+
if c.Nkey != "" {
45+
fb, err := os.ReadFile(c.Nkey)
46+
if err != nil {
47+
return "", fmt.Errorf("error opening nkey file %s: %v", c.Nkey, err)
48+
}
49+
b = append(b, fb...)
50+
}
51+
if c.Credentials != "" {
52+
fb, err := os.ReadFile(c.Credentials)
53+
if err != nil {
54+
return "", fmt.Errorf("error opening creds file %s: %v", c.Credentials, err)
55+
}
56+
b = append(b, fb...)
57+
}
58+
if len(c.TLSCAs) > 0 {
59+
for _, cert := range c.TLSCAs {
60+
fb, err := os.ReadFile(cert)
61+
if err != nil {
62+
return "", fmt.Errorf("error opening ca file %s: %v", cert, err)
63+
}
64+
b = append(b, fb...)
65+
}
66+
}
67+
if c.TLSCert != "" {
68+
fb, err := os.ReadFile(c.TLSCert)
69+
if err != nil {
70+
return "", fmt.Errorf("error opening cert file %s: %v", c.TLSCert, err)
71+
}
72+
b = append(b, fb...)
73+
}
74+
if c.TLSKey != "" {
75+
fb, err := os.ReadFile(c.TLSKey)
76+
if err != nil {
77+
return "", fmt.Errorf("error opening key file %s: %v", c.TLSKey, err)
78+
}
79+
b = append(b, fb...)
80+
}
81+
hash := sha256.New()
82+
hash.Write(b)
83+
return fmt.Sprintf("%x", hash.Sum(nil)), nil
84+
}
85+
86+
type natsContextDefaults struct {
87+
Name string
88+
URL string
89+
TLSCAs []string
90+
TLSCert string
91+
TLSKey string
92+
TLSConfig *tls.Config
93+
}
94+
95+
type pooledNatsConn struct {
96+
nc *nats.Conn
97+
cp *natsConnPool
98+
key string
99+
count uint64
100+
closed bool
101+
}
102+
103+
func (pc *pooledNatsConn) ReturnToPool() {
104+
pc.cp.Lock()
105+
pc.count--
106+
if pc.count == 0 {
107+
if pooledConn, ok := pc.cp.cache[pc.key]; ok && pc == pooledConn {
108+
delete(pc.cp.cache, pc.key)
109+
}
110+
pc.closed = true
111+
pc.cp.Unlock()
112+
pc.nc.Close()
113+
return
114+
}
115+
pc.cp.Unlock()
116+
}
117+
118+
type natsConnPool struct {
119+
sync.Mutex
120+
cache map[string]*pooledNatsConn
121+
logger *logrus.Logger
122+
group *singleflight.Group
123+
natsDefaults *natsContextDefaults
124+
natsOpts []nats.Option
125+
}
126+
127+
func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool {
128+
return &natsConnPool{
129+
cache: map[string]*pooledNatsConn{},
130+
group: &singleflight.Group{},
131+
logger: logger,
132+
natsDefaults: natsDefaults,
133+
natsOpts: natsOpts,
134+
}
135+
}
136+
137+
const getPooledConnMaxTries = 10
138+
139+
// Get returns a *pooledNatsConn
140+
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
141+
if cfg == nil {
142+
return nil, fmt.Errorf("nats context must not be nil")
143+
}
144+
145+
// copy cfg
146+
cfg = cfg.copy()
147+
148+
// set defaults
149+
if cfg.Name == "" {
150+
cfg.Name = cp.natsDefaults.Name
151+
}
152+
if cfg.URL == "" {
153+
cfg.URL = cp.natsDefaults.URL
154+
}
155+
if len(cfg.TLSCAs) == 0 {
156+
cfg.TLSCAs = cp.natsDefaults.TLSCAs
157+
}
158+
if cfg.TLSCert == "" {
159+
cfg.TLSCert = cp.natsDefaults.TLSCert
160+
}
161+
if cfg.TLSKey == "" {
162+
cfg.TLSKey = cp.natsDefaults.TLSKey
163+
}
164+
165+
// get hash
166+
key, err := cfg.hash()
167+
if err != nil {
168+
return nil, err
169+
}
170+
171+
for i := 0; i < getPooledConnMaxTries; i++ {
172+
connection, err := cp.getPooledConn(key, cfg)
173+
if err != nil {
174+
return nil, err
175+
}
176+
177+
cp.Lock()
178+
if connection.closed {
179+
// ReturnToPool closed this while lock not held, try again
180+
cp.Unlock()
181+
continue
182+
}
183+
184+
// increment count out of the pool
185+
connection.count++
186+
cp.Unlock()
187+
return connection, nil
188+
}
189+
190+
return nil, fmt.Errorf("failed to get pooled connection after %d attempts", getPooledConnMaxTries)
191+
}
192+
193+
// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group, but does not increment its count
194+
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) {
195+
conn, err, _ := cp.group.Do(key, func() (interface{}, error) {
196+
cp.Lock()
197+
pooledConn, ok := cp.cache[key]
198+
if ok && pooledConn.nc.IsConnected() {
199+
cp.Unlock()
200+
return pooledConn, nil
201+
}
202+
cp.Unlock()
203+
204+
opts := cp.natsOpts
205+
opts = append(opts, func(options *nats.Options) error {
206+
if cfg.Name != "" {
207+
options.Name = cfg.Name
208+
}
209+
if cfg.Token != "" {
210+
options.Token = cfg.Token
211+
}
212+
if cfg.Username != "" {
213+
options.User = cfg.Username
214+
}
215+
if cfg.Password != "" {
216+
options.Password = cfg.Password
217+
}
218+
return nil
219+
})
220+
221+
if cfg.JWT != "" && cfg.Seed != "" {
222+
opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed))
223+
}
224+
225+
if cfg.Nkey != "" {
226+
opt, err := nats.NkeyOptionFromSeed(cfg.Nkey)
227+
if err != nil {
228+
return nil, fmt.Errorf("unable to load nkey: %v", err)
229+
}
230+
opts = append(opts, opt)
231+
}
232+
233+
if cfg.Credentials != "" {
234+
opts = append(opts, nats.UserCredentials(cfg.Credentials))
235+
}
236+
237+
if len(cfg.TLSCAs) > 0 {
238+
opts = append(opts, nats.RootCAs(cfg.TLSCAs...))
239+
}
240+
241+
if cfg.TLSCert != "" && cfg.TLSKey != "" {
242+
opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey))
243+
}
244+
245+
nc, err := nats.Connect(cfg.URL, opts...)
246+
if err != nil {
247+
return nil, err
248+
}
249+
cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr())
250+
251+
connection := &pooledNatsConn{
252+
nc: nc,
253+
cp: cp,
254+
key: key,
255+
}
256+
257+
cp.Lock()
258+
cp.cache[key] = connection
259+
cp.Unlock()
260+
261+
return connection, err
262+
})
263+
264+
if err != nil {
265+
return nil, err
266+
}
267+
268+
connection, ok := conn.(*pooledNatsConn)
269+
if !ok {
270+
return nil, fmt.Errorf("not a pooledNatsConn")
271+
}
272+
return connection, nil
273+
}
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package jetstream
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
8+
"github.com/nats-io/nats.go"
9+
10+
natsservertest "github.com/nats-io/nats-server/v2/test"
11+
"github.com/sirupsen/logrus"
12+
testifyAssert "github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestConnPool(t *testing.T) {
16+
t.Parallel()
17+
18+
s := natsservertest.RunRandClientPortServer()
19+
defer s.Shutdown()
20+
o1 := &natsContext{
21+
Name: "Client 1",
22+
}
23+
o2 := &natsContext{
24+
Name: "Client 1",
25+
}
26+
o3 := &natsContext{
27+
Name: "Client 2",
28+
}
29+
30+
natsDefaults := &natsContextDefaults{
31+
URL: s.ClientURL(),
32+
}
33+
natsOptions := []nats.Option{
34+
nats.MaxReconnects(10240),
35+
}
36+
cp := newNatsConnPool(logrus.New(), natsDefaults, natsOptions)
37+
38+
var c1, c2, c3 *pooledNatsConn
39+
var c1e, c2e, c3e error
40+
wg := &sync.WaitGroup{}
41+
wg.Add(3)
42+
go func() {
43+
c1, c1e = cp.Get(o1)
44+
wg.Done()
45+
}()
46+
go func() {
47+
c2, c2e = cp.Get(o2)
48+
wg.Done()
49+
}()
50+
go func() {
51+
c3, c3e = cp.Get(o3)
52+
wg.Done()
53+
}()
54+
wg.Wait()
55+
56+
assert := testifyAssert.New(t)
57+
if assert.NoError(c1e) && assert.NoError(c2e) {
58+
assert.Same(c1, c2)
59+
}
60+
if assert.NoError(c3e) {
61+
assert.NotSame(c1, c3)
62+
assert.NotSame(c2, c3)
63+
}
64+
65+
c1.ReturnToPool()
66+
c3.ReturnToPool()
67+
time.Sleep(1 * time.Second)
68+
assert.False(c1.nc.IsClosed())
69+
assert.False(c2.nc.IsClosed())
70+
assert.True(c3.nc.IsClosed())
71+
72+
c4, c4e := cp.Get(o1)
73+
if assert.NoError(c4e) {
74+
assert.Same(c2, c4)
75+
}
76+
77+
c2.ReturnToPool()
78+
c4.ReturnToPool()
79+
time.Sleep(1 * time.Second)
80+
assert.True(c1.nc.IsClosed())
81+
assert.True(c2.nc.IsClosed())
82+
assert.True(c4.nc.IsClosed())
83+
84+
c5, c5e := cp.Get(o1)
85+
if assert.NoError(c5e) {
86+
assert.NotSame(c1, c5)
87+
}
88+
89+
c5.ReturnToPool()
90+
time.Sleep(1 * time.Second)
91+
assert.True(c5.nc.IsClosed())
92+
}

0 commit comments

Comments
 (0)