Skip to content

Commit 7ad2712

Browse files
committed
Add connection pooling for controller-runtime controllers
1 parent e3bf14d commit 7ad2712

12 files changed

+332
-49
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ $(ENVTEST): $(LOCALBIN)
209209
test: envtest
210210
go vet ./controllers/... ./pkg/natsreloader/... ./internal/controller/...
211211
$(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path ## Get k8s binaries
212-
go test -race -cover -count=1 -timeout 10s ./controllers/... ./pkg/natsreloader/... ./internal/controller/...
212+
go test -race -cover -count=1 -timeout 30s ./controllers/... ./pkg/natsreloader/... ./internal/controller/...
213213

214214
.PHONY: clean
215215
clean:

cicd/Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#syntax=docker/dockerfile:1.13
22
ARG GO_APP
33

4-
FROM alpine:3.21.3 as deps
4+
FROM alpine:3.21.3 AS deps
55

66
ARG GO_APP
77
ARG GORELEASER_DIST_DIR=/go/src/dist

cicd/Dockerfile_goreleaser

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#syntax=docker/dockerfile:1.13
2-
FROM --platform=$BUILDPLATFORM golang:1.24.0-bullseye as build
2+
FROM --platform=$BUILDPLATFORM golang:1.24.0-bullseye AS build
33

44

55
RUN <<EOT

internal/controller/client.go

+86-31
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,90 @@
11
package controller
22

33
import (
4+
"crypto/sha256"
5+
"encoding/json"
46
"fmt"
7+
"os"
58

69
"github.com/nats-io/jsm.go"
710
"github.com/nats-io/nats.go"
811
"github.com/nats-io/nats.go/jetstream"
912
)
1013

1114
type NatsConfig struct {
12-
ClientName string
13-
ServerURL string
14-
Certificate string
15-
Key string
16-
TLSFirst bool
17-
CAs []string
18-
Credentials string
19-
NKey string
20-
Token string
21-
User string
22-
Password string
15+
ClientName string `json:"name,omitempty"`
16+
ServerURL string `json:"url,omitempty"`
17+
Certificate string `json:"tls_cert,omitempty"`
18+
Key string `json:"tls_key,omitempty"`
19+
TLSFirst bool `json:"tls_first,omitempty"`
20+
CAs []string `json:"tls_ca,omitempty"`
21+
Credentials string `json:"credential,omitempty"`
22+
NKey string `json:"nkey,omitempty"`
23+
Token string `json:"token,omitempty"`
24+
User string `json:"username,omitempty"`
25+
Password string `json:"password,omitempty"`
26+
}
27+
28+
func (o *NatsConfig) Copy() *NatsConfig {
29+
if o == nil {
30+
return nil
31+
}
32+
33+
cp := *o
34+
return &cp
35+
}
36+
37+
func (o *NatsConfig) Hash() (string, error) {
38+
b, err := json.Marshal(o)
39+
if err != nil {
40+
return "", fmt.Errorf("error marshaling config to json: %v", err)
41+
}
42+
43+
if o.NKey != "" {
44+
fb, err := os.ReadFile(o.NKey)
45+
if err != nil {
46+
return "", fmt.Errorf("error opening nkey file %s: %v", o.NKey, err)
47+
}
48+
b = append(b, fb...)
49+
}
50+
51+
if o.Credentials != "" {
52+
fb, err := os.ReadFile(o.Credentials)
53+
if err != nil {
54+
return "", fmt.Errorf("error opening creds file %s: %v", o.Credentials, err)
55+
}
56+
b = append(b, fb...)
57+
}
58+
59+
if len(o.CAs) > 0 {
60+
for _, cert := range o.CAs {
61+
fb, err := os.ReadFile(cert)
62+
if err != nil {
63+
return "", fmt.Errorf("error opening ca file %s: %v", cert, err)
64+
}
65+
b = append(b, fb...)
66+
}
67+
}
68+
69+
if o.Certificate != "" {
70+
fb, err := os.ReadFile(o.Certificate)
71+
if err != nil {
72+
return "", fmt.Errorf("error opening cert file %s: %v", o.Certificate, err)
73+
}
74+
b = append(b, fb...)
75+
}
76+
77+
if o.Key != "" {
78+
fb, err := os.ReadFile(o.Key)
79+
if err != nil {
80+
return "", fmt.Errorf("error opening key file %s: %v", o.Key, err)
81+
}
82+
b = append(b, fb...)
83+
}
84+
85+
hash := sha256.New()
86+
hash.Write(b)
87+
return fmt.Sprintf("%x", hash.Sum(nil)), nil
2388
}
2489

2590
func (o *NatsConfig) Overlay(overlay *NatsConfig) {
@@ -125,15 +190,10 @@ type Closable interface {
125190
Close()
126191
}
127192

128-
func CreateJSMClient(cfg *NatsConfig, pedantic bool) (*jsm.Manager, Closable, error) {
129-
nc, err := createNatsConn(cfg, pedantic)
193+
func CreateJSMClient(conn *pooledConnection, pedantic bool) (*jsm.Manager, error) {
194+
major, minor, _, err := versionComponents(conn.nc.ConnectedServerVersion())
130195
if err != nil {
131-
return nil, nil, fmt.Errorf("create nats connection: %w", err)
132-
}
133-
134-
major, minor, _, err := versionComponents(nc.ConnectedServerVersion())
135-
if err != nil {
136-
return nil, nil, fmt.Errorf("parse server version: %w", err)
196+
return nil, fmt.Errorf("parse server version: %w", err)
137197
}
138198

139199
// JetStream pedantic mode unsupported prior to NATS Server 2.11
@@ -146,28 +206,23 @@ func CreateJSMClient(cfg *NatsConfig, pedantic bool) (*jsm.Manager, Closable, er
146206
jsmOpts = append(jsmOpts, jsm.WithPedanticRequests())
147207
}
148208

149-
jsmClient, err := jsm.New(nc, jsmOpts...)
209+
jsmClient, err := jsm.New(conn.nc, jsmOpts...)
150210
if err != nil {
151-
return nil, nil, fmt.Errorf("new jsm client: %w", err)
211+
return nil, fmt.Errorf("new jsm client: %w", err)
152212
}
153213

154-
return jsmClient, nc, nil
214+
return jsmClient, nil
155215
}
156216

157217
// CreateJetStreamClient creates new Jetstream client with a connection based on the given NatsConfig.
158218
// Returns a jetstream.Jetstream client and the Closable of the underlying connection.
159219
// Close should be called when the client is no longer used.
160-
func CreateJetStreamClient(cfg *NatsConfig, pedantic bool) (jetstream.JetStream, Closable, error) {
161-
nc, err := createNatsConn(cfg, pedantic)
162-
if err != nil {
163-
return nil, nil, fmt.Errorf("create nats connection: %w", err)
164-
}
165-
166-
js, err := jetstream.New(nc)
220+
func CreateJetStreamClient(conn *pooledConnection, pedantic bool) (jetstream.JetStream, error) {
221+
js, err := jetstream.New(conn.nc)
167222
if err != nil {
168-
return nil, nil, fmt.Errorf("new jetstream: %w", err)
223+
return nil, fmt.Errorf("new jetstream: %w", err)
169224
}
170-
return js, nc, nil
225+
return js, nil
171226
}
172227

173228
func createNatsConn(cfg *NatsConfig, pedantic bool) (*nats.Conn, error) {
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package controller
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"github.com/nats-io/nats.go"
8+
)
9+
10+
type pooledConnection struct {
11+
nc *nats.Conn
12+
pool *connectionPool
13+
hash string
14+
refCount int
15+
}
16+
17+
func (pc *pooledConnection) Close() {
18+
if pc.pool != nil {
19+
pc.pool.release(pc.hash)
20+
} else if pc.nc != nil {
21+
pc.nc.Close() // Close directly if not pool-managed
22+
}
23+
}
24+
25+
type connectionPool struct {
26+
connections map[string]*pooledConnection
27+
gracePeriod time.Duration
28+
mu sync.Mutex
29+
}
30+
31+
func newConnPool(gracePeriod time.Duration) *connectionPool {
32+
return &connectionPool{
33+
connections: make(map[string]*pooledConnection),
34+
gracePeriod: gracePeriod,
35+
}
36+
}
37+
38+
func (p *connectionPool) Get(c *NatsConfig, pedantic bool) (*pooledConnection, error) {
39+
p.mu.Lock()
40+
defer p.mu.Unlock()
41+
42+
hash, err := c.Hash()
43+
if err != nil {
44+
// If hash fails, create a new non-pooled connection
45+
nc, err := createNatsConn(c, pedantic)
46+
if err != nil {
47+
return nil, err
48+
}
49+
return &pooledConnection{nc: nc}, nil
50+
}
51+
52+
if pc, ok := p.connections[hash]; ok && !pc.nc.IsClosed() {
53+
pc.refCount++
54+
return pc, nil
55+
}
56+
57+
nc, err := createNatsConn(c, pedantic)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
pc := &pooledConnection{
63+
nc: nc,
64+
pool: p,
65+
hash: hash,
66+
refCount: 1,
67+
}
68+
p.connections[hash] = pc
69+
70+
return pc, nil
71+
}
72+
73+
func (p *connectionPool) release(hash string) {
74+
p.mu.Lock()
75+
defer p.mu.Unlock()
76+
77+
pc, ok := p.connections[hash]
78+
if !ok {
79+
return
80+
}
81+
82+
pc.refCount--
83+
if pc.refCount < 1 {
84+
go func() {
85+
if p.gracePeriod > 0 {
86+
time.Sleep(p.gracePeriod)
87+
}
88+
89+
p.mu.Lock()
90+
defer p.mu.Unlock()
91+
92+
if pc, ok := p.connections[hash]; ok && pc.refCount < 1 {
93+
pc.nc.Close()
94+
delete(p.connections, hash)
95+
}
96+
}()
97+
}
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package controller
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
8+
natsservertest "github.com/nats-io/nats-server/v2/test"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestConnPool(t *testing.T) {
13+
t.Parallel()
14+
15+
s := natsservertest.RunRandClientPortServer()
16+
defer s.Shutdown()
17+
18+
c1 := &NatsConfig{
19+
ClientName: "Client 1",
20+
ServerURL: s.ClientURL(),
21+
}
22+
23+
c2 := &NatsConfig{
24+
ClientName: "Client 1",
25+
ServerURL: s.ClientURL(),
26+
}
27+
28+
c3 := &NatsConfig{
29+
ClientName: "Client 2",
30+
ServerURL: s.ClientURL(),
31+
}
32+
33+
pool := newConnPool(0)
34+
35+
var conn1, conn2, conn3 *pooledConnection
36+
var err1, err2, err3 error
37+
38+
wg := &sync.WaitGroup{}
39+
wg.Add(3)
40+
41+
go func() {
42+
conn1, err1 = pool.Get(c1, true)
43+
wg.Done()
44+
}()
45+
go func() {
46+
conn2, err2 = pool.Get(c2, true)
47+
wg.Done()
48+
}()
49+
go func() {
50+
conn3, err3 = pool.Get(c3, true)
51+
wg.Done()
52+
}()
53+
wg.Wait()
54+
55+
require := require.New(t)
56+
57+
require.NoError(err1)
58+
require.NoError(err2)
59+
require.NoError(err3)
60+
61+
require.Same(conn1, conn2)
62+
require.NotSame(conn1, conn3)
63+
require.NotSame(conn2, conn3)
64+
65+
conn1.Close()
66+
conn3.Close()
67+
68+
time.Sleep(time.Second)
69+
70+
require.False(conn1.nc.IsClosed())
71+
require.False(conn2.nc.IsClosed())
72+
require.True(conn3.nc.IsClosed())
73+
74+
conn4, err4 := pool.Get(c1, true)
75+
require.NoError(err4)
76+
require.Same(conn1, conn4)
77+
require.Same(conn2, conn4)
78+
79+
conn2.Close()
80+
conn4.Close()
81+
82+
time.Sleep(time.Second)
83+
84+
require.True(conn1.nc.IsClosed())
85+
require.True(conn2.nc.IsClosed())
86+
require.True(conn3.nc.IsClosed())
87+
require.True(conn4.nc.IsClosed())
88+
89+
conn5, err5 := pool.Get(c1, true)
90+
require.NoError(err5)
91+
require.NotSame(conn1, conn5)
92+
93+
conn5.Close()
94+
95+
time.Sleep(time.Second)
96+
97+
require.True(conn5.nc.IsClosed())
98+
}

internal/controller/consumer_controller_test.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,14 @@ var _ = Describe("Consumer Controller", func() {
541541
By("setting up the alternative server")
542542
altServer := CreateTestServer()
543543
defer altServer.Shutdown()
544+
545+
connPool := newConnPool(0)
546+
conn, err := connPool.Get(&NatsConfig{ServerURL: altServer.ClientURL()}, true)
547+
Expect(err).NotTo(HaveOccurred())
548+
544549
// Setup altClient for alternate server
545-
altClient, closer, err := CreateJetStreamClient(&NatsConfig{ServerURL: altServer.ClientURL()}, true)
546-
defer closer.Close()
550+
altClient, err := CreateJetStreamClient(conn, true)
551+
defer conn.Close()
547552
Expect(err).NotTo(HaveOccurred())
548553

549554
By("setting up the stream on the alternative server")

0 commit comments

Comments
 (0)