-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmux.go
286 lines (244 loc) · 7.25 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
package gomux
import (
"crypto/cipher"
"errors"
"net"
"sync"
"time"
"golang.org/x/crypto/blake2b"
"golang.org/x/crypto/chacha20poly1305"
)
// NOTE: This package makes heavy use of sync.Cond to manage concurrent streams
// multiplexed onto a single connection. sync.Cond is rarely used (since it is
// almost never the right tool for the job), and consequently, Go programmers
// tend to be unfamiliar with its semantics. Nevertheless, it is currently the
// only way to achieve optimal throughput in a stream multiplexer, so we make
// careful use of it here. Please make sure you understand sync.Cond thoroughly
// before attempting to modify this code.
// Errors relating to stream or mux shutdown.
var (
	// ErrClosedConn is the terminal error of a Mux that was shut down
	// locally via Mux.Close.
	ErrClosedConn = errors.New("underlying connection was closed")
	// ErrPeerClosedConn is the terminal error of a Mux whose peer sent a
	// close-mux frame.
	ErrPeerClosedConn = errors.New("peer closed mux gracefully")

	// The stream-level errors below are not referenced in this file;
	// presumably they are returned by Stream methods (confirm in the stream
	// implementation).

	// ErrClosedStream reports use of a Stream that was closed locally.
	ErrClosedStream = errors.New("stream was gracefully closed")
	// ErrPeerClosedStream reports that the peer closed the Stream.
	ErrPeerClosedStream = errors.New("peer closed stream gracefully")
	// ErrWriteClosed reports a write after the Stream's write side was closed.
	ErrWriteClosed = errors.New("write end of stream closed")
)
// A Mux multiplexes multiple duplex Streams onto a single net.Conn.
//
// Its state is split between two mutexes so the read and write paths do not
// contend: readMutex guards stream bookkeeping and writeMutex guards the
// outgoing frame buffers. setErr, which needs both, acquires readMutex first;
// any other code that takes both should use the same order.
type Mux struct {
// conn is the multiplexed connection; setErr closes it.
conn net.Conn
// acceptChan queues Streams opened by the peer until AcceptStream takes
// them; it is closed when the Mux shuts down.
acceptChan chan *Stream
// aead is the XChaCha20-Poly1305 cipher built by newMux from the PSK and
// passed to appendFrame/readFrame.
aead cipher.AEAD
// readMutex guards readErr, streams and nextID (nextID is used by
// OpenStream, not only readLoop).
readMutex sync.Mutex
// readErr is the Mux's terminal error; within this file only setErr sets it.
readErr error
// streams maps stream IDs to live Streams; readLoop routes frames by it.
streams map[uint32]*Stream
// nextID is the ID OpenStream assigns next. It advances by 2, so the side
// created by Client (start 0) and the side created by Server (start 1)
// never choose the same ID.
nextID uint32
// writeMutex guards writeErr and writeBuf, and is the Locker of both
// writeCond and bufferCond.
writeMutex sync.Mutex
// writeErr mirrors readErr for the write path; within this file only setErr
// sets it.
writeErr error
// writeCond wakes writeLoop when bufferFrame queues a frame, when the
// keepalive timer fires, or when setErr shuts the Mux down.
writeCond sync.Cond
// bufferCond wakes bufferFrame callers (and Close) waiting for writeBuf to
// drain.
bufferCond sync.Cond
// writeBuf collects sealed frames from bufferFrame until writeLoop takes
// them.
writeBuf []byte
// sendBuf holds the frames writeLoop is currently writing. It is swapped
// with writeBuf under writeMutex and is otherwise touched only by writeLoop.
sendBuf []byte
// writeBufA and writeBufB are the two buffers newMux allocates and that
// writeBuf/sendBuf initially refer to; this file does not read them after
// newMux.
writeBufA []byte
writeBufB []byte
}
// setErr records err as the Mux's terminal error and wakes every goroutine
// that may be blocked on the Mux: the registered Streams, writeLoop, and any
// bufferFrame or Close callers. It also closes the underlying connection,
// which makes readLoop's pending (or next) Read fail, so readLoop exits and
// closes acceptChan, releasing AcceptStream.
//
// Only the first call has an effect; it returns err. Later calls are no-ops
// that return the error recorded by the first call.
func (m *Mux) setErr(err error) error {
	// Lock order: readMutex before writeMutex.
	m.readMutex.Lock()
	defer m.readMutex.Unlock()
	if m.readErr != nil {
		return m.readErr
	}
	m.readErr = err

	m.writeMutex.Lock()
	defer m.writeMutex.Unlock()
	if m.writeErr != nil {
		// Defensive: within this file setErr is the only writer of readErr
		// and writeErr, and it sets both while holding both locks.
		return m.writeErr
	}
	m.writeErr = err

	for _, s := range m.streams {
		s.cond.L.Lock()
		s.err = err
		s.cond.L.Unlock()
		s.cond.Broadcast()
	}
	// Best-effort: the Mux is dead either way, so a Close error is not useful.
	m.conn.Close()
	m.writeCond.Broadcast()
	m.bufferCond.Broadcast()
	// acceptChan is intentionally NOT closed here. readLoop is its only
	// sender and may be blocked in a send right now; closing the channel from
	// another goroutine would make that send panic. readLoop closes it when
	// it exits.
	return err
}

// bufferFrame blocks until m.writeBuf has room for the sealed frame (header,
// payload and AEAD overhead), appends it, and wakes writeLoop. If the Mux
// shuts down first it returns m.writeErr without queuing anything.
//
// NOTE(review): a frame larger than cap(m.writeBuf) would wait forever;
// presumably payloads are bounded well below writeBufferSize — confirm.
func (m *Mux) bufferFrame(h frameHeader, payload []byte) error {
	m.writeMutex.Lock()
	for len(m.writeBuf)+frameHeaderSize+len(payload)+chacha20poly1305.Overhead > cap(m.writeBuf) && m.writeErr == nil {
		m.bufferCond.Wait()
	}
	if m.writeErr != nil {
		m.writeMutex.Unlock()
		return m.writeErr
	}
	m.writeBuf = appendFrame(m.writeBuf, m.aead, h, payload)
	m.writeMutex.Unlock()
	// wake the writeLoop
	m.writeCond.Signal()
	// Pass the wakeup on to at most one other waiting bufferFrame call; it
	// re-checks for room itself.
	m.bufferCond.Signal()
	return nil
}

// writeLoop performs all Writes to the Mux's net.Conn. It waits until
// bufferFrame has queued frames in m.writeBuf, swaps that buffer with
// m.sendBuf so producers can keep queuing while it writes, then flushes
// sendBuf to the connection. If nothing has been written for
// keepaliveInterval it sends a keepalive frame instead. It returns once
// writeErr is set or a Write fails.
func (m *Mux) writeLoop() {
	const keepaliveInterval = 10 * time.Minute
	nextKeepalive := time.Now().Add(keepaliveInterval)
	// The timer wakes writeLoop when a keepalive is due. It signals while
	// holding writeMutex: an unlocked Signal could land between writeLoop's
	// condition check and its call to Wait, and that wakeup would be lost.
	timer := time.AfterFunc(keepaliveInterval, func() {
		m.writeMutex.Lock()
		m.writeCond.Signal()
		m.writeMutex.Unlock()
	})
	defer timer.Stop()
	for {
		m.writeMutex.Lock()
		for len(m.writeBuf) == 0 && m.writeErr == nil && time.Now().Before(nextKeepalive) {
			m.writeCond.Wait()
		}
		if m.writeErr != nil {
			m.writeMutex.Unlock()
			return
		}
		// Prefer real frames; only send a keepalive if nothing is queued.
		// Even when woken by the timer there may be frames ready, and they
		// then serve as the keepalive.
		if len(m.writeBuf) == 0 {
			m.writeBuf = appendFrame(m.writeBuf[:0], m.aead, frameHeader{flags: flagKeepalive}, nil)
		}
		// Swap buffers so bufferFrame can keep queuing while we Write.
		m.writeBuf, m.sendBuf = m.sendBuf, m.writeBuf
		m.writeMutex.Unlock()
		// writeBuf is now empty: wake at most one bufferFrame call.
		m.bufferCond.Signal()

		// Restart the keepalive countdown. The deadline is computed BEFORE
		// the timer is re-armed so the timer can never fire ahead of it;
		// otherwise writeLoop could wake, see the deadline not yet reached,
		// and go back to sleep with no timer left pending.
		nextKeepalive = time.Now().Add(keepaliveInterval)
		timer.Reset(keepaliveInterval)

		if _, err := m.conn.Write(m.sendBuf); err != nil {
			m.setErr(err)
			return
		}
		// Keep the capacity for the next swap.
		m.sendBuf = m.sendBuf[:0]
	}
}

// deleteStream removes the Stream with the given id from the routing table;
// readLoop drops frames that arrive for it afterwards.
func (m *Mux) deleteStream(id uint32) {
	m.readMutex.Lock()
	delete(m.streams, id)
	m.readMutex.Unlock()
}

// readLoop performs all Reads from the Mux's net.Conn. It decodes one frame
// at a time and either ignores it (keepalive), registers and queues a new
// peer-opened Stream, shuts the Mux down (close-mux or read error), or hands
// it to the Stream it addresses. Frames for unknown stream IDs are dropped.
//
// readLoop is the only sender on acceptChan and therefore the one that
// closes it, on every exit path.
func (m *Mux) readLoop() {
	defer close(m.acceptChan)
	frameBuf := make([]byte, maxPayloadSize+chacha20poly1305.Overhead)
	for {
		h, payload, err := readFrame(m.conn, m.aead, frameBuf)
		if err != nil {
			m.setErr(err)
			return
		}
		switch h.flags {
		case flagKeepalive:
			continue
		case flagOpenStream:
			m.readMutex.Lock()
			if m.readErr != nil {
				// The Mux shut down after this frame was read. setErr has
				// already notified the registered Streams, so a Stream added
				// now would never learn of the error; stop instead.
				m.readMutex.Unlock()
				return
			}
			s := newStream(h.id, m)
			m.streams[h.id] = s
			m.readMutex.Unlock()
			// May block if nobody is calling AcceptStream; safe because no
			// other goroutine closes acceptChan.
			m.acceptChan <- s
		case flagCloseMux:
			m.setErr(ErrPeerClosedConn)
			return
		default:
			m.readMutex.Lock()
			stream, found := m.streams[h.id]
			m.readMutex.Unlock()
			if found {
				stream.consumeFrame(h, payload)
			}
		}
	}
}
// Close shuts down the Mux. It queues a close-mux frame so the peer can shut
// down gracefully, waits until writeLoop has taken every queued frame out of
// writeBuf, then fails all Streams with ErrClosedConn and closes the
// underlying net.Conn (via setErr).
//
// It returns nil if the Mux ends up closed with ErrClosedConn (including on
// repeated calls); if the Mux had already failed for another reason, that
// error is returned.
//
// NOTE(review): the wait below ends when writeBuf is empty, i.e. when
// writeLoop has swapped the frames into sendBuf, not when conn.Write has
// returned. setErr may therefore close the conn while the close-mux frame is
// still being written, and the peer then sees a plain read error instead of
// ErrPeerClosedConn. Fixing this needs writeLoop to publish completed writes
// (e.g. a flush counter guarded by writeMutex).
func (m *Mux) Close() error {
// tell the peer we are shutting down
h := frameHeader{flags: flagCloseMux}
// Error ignored: if the Mux is already dead, writeErr is set, the wait
// below exits immediately, and setErr reports the existing error.
m.bufferFrame(h, nil)
// wait until writeLoop has taken the queued frames (including the close
// frame) out of writeBuf, or the Mux has failed
m.writeMutex.Lock()
for len(m.writeBuf) > 0 && m.writeErr == nil {
m.bufferCond.Wait()
}
m.writeMutex.Unlock()
err := m.setErr(ErrClosedConn)
if err == ErrClosedConn {
return nil
}
return err
}
// AcceptStream blocks until the peer opens a Stream and returns it. After the
// Mux has shut down and every Stream that was already queued has been
// returned, it returns a nil Stream and the error that terminated the Mux.
func (m *Mux) AcceptStream() (*Stream, error) {
	s, ok := <-m.acceptChan
	if !ok {
		// acceptChan is only closed after readErr has been recorded, and a
		// channel close happens-before a receive that observes it, so this
		// unlocked read of readErr is race-free.
		return nil, m.readErr
	}
	return s, nil
}
// OpenStream creates a new Stream and tells the peer it exists.
//
// Stream IDs advance by 2 from the Mux's start ID, so the Client side (start
// 0) and the Server side (start 1) never pick the same ID. If the Mux has
// already shut down, or the open frame cannot be queued, OpenStream returns a
// nil Stream and the Mux's terminal error.
func (m *Mux) OpenStream() (*Stream, error) {
	m.readMutex.Lock()
	if m.readErr != nil {
		// setErr has already notified every registered Stream; a Stream
		// added now would never be told the Mux is dead.
		err := m.readErr
		m.readMutex.Unlock()
		return nil, err
	}
	s := newStream(m.nextID, m)
	m.streams[s.id] = s
	m.nextID += 2 // uint32 wraparound intended
	m.readMutex.Unlock()

	// send flagOpenStream to tell the peer the stream exists
	h := frameHeader{
		id:    s.id,
		flags: flagOpenStream,
	}
	if err := m.bufferFrame(h, nil); err != nil {
		// The Mux died after we registered the Stream; unregister it so it
		// does not linger in the routing table, and return no Stream.
		m.readMutex.Lock()
		delete(m.streams, s.id)
		m.readMutex.Unlock()
		return nil, err
	}
	return s, nil
}
// newMux builds a Mux over conn whose locally opened Streams get IDs
// startID, startID+2, ..., derives the frame cipher from psk, and starts the
// readLoop and writeLoop goroutines.
func newMux(conn net.Conn, startID uint32, psk string) *Mux {
	// The AEAD key is the BLAKE2b-256 digest of the pre-shared key.
	key := blake2b.Sum256([]byte(psk))
	// NewX only fails for keys that are not chacha20poly1305.KeySize (32)
	// bytes long; a BLAKE2b-256 digest always is, so the error is ignored.
	aead, _ := chacha20poly1305.NewX(key[:])

	bufA := make([]byte, 0, writeBufferSize)
	bufB := make([]byte, 0, writeBufferSize)
	m := &Mux{
		conn:       conn,
		acceptChan: make(chan *Stream, 256),
		aead:       aead,
		streams:    make(map[uint32]*Stream),
		nextID:     startID,
		writeBufA:  bufA,
		writeBufB:  bufB,
		writeBuf:   bufA, // frames are queued into A first...
		sendBuf:    bufB, // ...while B is the idle send buffer
	}
	// Both conditions share writeMutex, which guards the write buffers.
	m.writeCond.L = &m.writeMutex
	m.bufferCond.L = &m.writeMutex

	go m.readLoop()
	go m.writeLoop()
	return m
}
// Client creates and initializes a new client-side Mux on the provided conn.
// Client takes ownership of the conn: the Mux closes it when it shuts down.
// Streams opened by the client side use even IDs (0, 2, 4, ...), so they
// never collide with Streams opened by the Server side.
func Client(conn net.Conn, psk string) *Mux {
	return newMux(conn, 0, psk)
}
// Server creates and initializes a new server-side Mux on the provided conn.
// Server takes ownership of the conn: the Mux closes it when it shuts down.
// Streams opened by the server side use odd IDs (1, 3, 5, ...), so they never
// collide with Streams opened by the Client side.
func Server(conn net.Conn, psk string) *Mux {
	return newMux(conn, 1, psk)
}