From 019832a029b10575f25bef202b4d78d91f4d17e1 Mon Sep 17 00:00:00 2001 From: Robin Thellend Date: Wed, 20 Nov 2024 11:42:06 -0800 Subject: [PATCH] Add support for WebSockets (#161) ### Description Add support for forwarding WebSocket requests to arbitrary TCP servers. WebSockets were already forwarded transparently to backends before, and that is not changing. The new feature lets tlsproxy itself handle the WebSocket request and forward them to any TCP servers. The content of BinaryMessages is streamed to the remote server, and data received from the server is sent back to the client also as BinaryMessages. ### Type of change * [x] New feature * [ ] Feature improvement * [ ] Bug fix * [ ] Documentation * [ ] Cleanup / refactoring * [ ] Other (please explain) ### How is this change tested ? * [ ] Unit tests * [x] Manual tests (explain) * [ ] Tests are not needed --- CHANGELOG.md | 6 +++ go.mod | 1 + go.sum | 2 + main.go | 11 ++++ proxy/config.go | 13 +++++ proxy/proxy.go | 11 ++++ proxy/websocket.go | 128 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 172 insertions(+) create mode 100644 proxy/websocket.go diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d0be4..b31e31b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # TLSPROXY Release Notes +## next + +### :star2: New features + +* Add support for forwarding WebSocket requests to arbitrary TCP servers. WebSockets were already forwarded transparently to backends before, and that is not changing. The new feature lets tlsproxy itself handle the WebSocket request and forward them to any TCP servers. The content of BinaryMessages is streamed to the remote server, and data received from the server is sent back to the client also as BinaryMessages. + ## v0.12.0 ### :star: Feature improvement diff --git a/go.mod b/go.mod index 3f7839e..a56eb79 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/google/go-tpm v0.9.1 // indirect github.com/google/pprof v0.0.0-20241101162523-b92577c0c142 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect github.com/onsi/ginkgo/v2 v2.21.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect diff --git a/go.sum b/go.sum index 29f4944..6639462 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/google/pprof v0.0.0-20241101162523-b92577c0c142 h1:sAGdeJj0bnMgUNVeUp github.com/google/pprof v0.0.0-20241101162523-b92577c0c142/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= diff --git a/main.go b/main.go index a83f865..0211559 100644 --- a/main.go +++ b/main.go @@ -112,6 +112,17 @@ func main() { ctx, canc := context.WithTimeout(ctx, *shutdownGraceFlag) defer canc() + + go func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT) + signal.Notify(ch, syscall.SIGTERM) + select { + case <-ch: + os.Exit(1) + case <-ctx.Done(): + } + }() p.Shutdown(ctx) } diff --git a/proxy/config.go b/proxy/config.go index 3f89f72..02bdf3b 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -166,6 +166,13 @@ type Config struct { // Each backend can be associated with one group. The group's limits // are shared between all the backends associated with it. BWLimits []*BWLimit `yaml:"bwLimits,omitempty"` + // WebSockets is a list of WebSocket endpoints and where they get + // forwarded. + // Incoming WebSocket requests are bridged to TCP connections. The + // content of BinaryMessages are streamed to the TCP server, and + // data received from the server is sent to the client also + // as BinaryMessages. + WebSockets []*WebSocketConfig `yaml:"webSockets,omitempty"` acceptProxyHeaderFrom []*net.IPNet } @@ -203,6 +210,12 @@ type TLSCertificate struct { CertFile string `yaml:"cert"` } +// WebSocketConfig specifies a WebSocket endpoint. +type WebSocketConfig struct { + Endpoint string `yaml:"endpoint"` + Address string `yaml:"address,omitempty"` +} + // Backend encapsulates the data of one backend. type Backend struct { // ServerNames is the list of all the server names for this service, diff --git a/proxy/proxy.go b/proxy/proxy.go index 988fc0d..13a7fce 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -55,6 +55,7 @@ import ( "github.com/c2FmZQ/storage/autocertcache" "github.com/c2FmZQ/storage/crypto" "github.com/c2FmZQ/tpm" + "github.com/gorilla/websocket" "github.com/pires/go-proxyproto" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" @@ -117,6 +118,7 @@ type Proxy struct { mk crypto.MasterKey store *storage.Storage tokenManager *tokenmanager.TokenManager + wsUpgrader *websocket.Upgrader mu sync.RWMutex connClosed *sync.Cond @@ -791,6 +793,15 @@ func (p *Proxy) Reconfigure(cfg *Config) error { }, pp.Endpoint) } } + if len(cfg.WebSockets) > 0 && p.wsUpgrader == nil { + p.wsUpgrader = newWebSocketUpgrader() + } + for _, ws := range cfg.WebSockets { + addLocalHandler(localHandler{ + desc: "WebSocket Endpoint", + handler: logHandler(p.webSocketHandler(*ws)), + }, ws.Endpoint) + } for _, be := range backends { sort.Slice(be.localHandlers, func(i, j int) bool { a := be.localHandlers[i].host diff --git a/proxy/websocket.go b/proxy/websocket.go new file mode 100644 index 0000000..d089ca7 --- /dev/null +++ b/proxy/websocket.go @@ -0,0 +1,128 @@ +// MIT License +// +// Copyright (c) 2024 TTBT Enterprises LLC +// Copyright (c) 2024 Robin Thellend +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package proxy + +import ( + "io" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +func newWebSocketUpgrader() *websocket.Upgrader { + return &websocket.Upgrader{ + ReadBufferSize: 8192, + WriteBufferSize: 8192, + } +} + +func (p *Proxy) webSocketHandler(cfg WebSocketConfig) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + in, err := p.wsUpgrader.Upgrade(w, req, nil) + if err != nil { + p.logErrorF("ERR %v", err) + return + } + defer in.Close() + + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + out, err := dialer.DialContext(req.Context(), "tcp", cfg.Address) + if err != nil { + p.logErrorF("ERR webSocketHandler: %v", err) + return + } + + done := make(chan bool, 2) + + lastActive := time.Now() + in.SetPongHandler(func(string) error { + lastActive = time.Now() + return nil + }) + go func() { + ctx := req.Context() + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if time.Since(lastActive) > 30*time.Second { + done <- true + return + } + if err := in.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second)); err != nil { + p.logErrorF("ERR WriteControl: %v", err) + } + } + } + }() + + // in -> out loop + go func() { + defer func() { + done <- true + }() + for { + messageType, r, err := in.NextReader() + if err != nil { + return + } + if messageType != websocket.BinaryMessage { + continue + } + if _, err := io.Copy(out, r); err != nil { + return + } + } + }() + + // out -> in loop + go func() { + defer func() { + done <- true + }() + buf := make([]byte, 1024) + for { + n, err := out.Read(buf) + if n > 0 { + if err := in.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil { + return + } + } + if err != nil { + return + } + } + }() + + <-done + }) +}