Skip to content

Commit

Permalink
Add support for WebSockets (#161)
Browse files Browse the repository at this point in the history
### Description

Add support for forwarding WebSocket requests to arbitrary TCP servers.

WebSockets were already forwarded transparently to backends before, and
that is not changing.

The new feature lets tlsproxy itself handle the WebSocket request and
forward them to any TCP servers. The content of BinaryMessages is
streamed to the remote server, and data received from the server is sent
back to the client also as BinaryMessages.

### Type of change

* [x] New feature
* [ ] Feature improvement
* [ ] Bug fix
* [ ] Documentation
* [ ] Cleanup / refactoring
* [ ] Other (please explain)


### How is this change tested ?

* [ ] Unit tests
* [x] Manual tests (explain)
* [ ] Tests are not needed
  • Loading branch information
rthellend authored Nov 20, 2024
1 parent 7d8beab commit 019832a
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# TLSPROXY Release Notes

## next

### :star2: New features

* Add support for forwarding WebSocket requests to arbitrary TCP servers. WebSockets were already forwarded transparently to backends before, and that is not changing. The new feature lets tlsproxy itself handle the WebSocket request and forward them to any TCP servers. The content of BinaryMessages is streamed to the remote server, and data received from the server is sent back to the client also as BinaryMessages.

## v0.12.0

### :star: Feature improvement
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ require (
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/go-tpm v0.9.1 // indirect
github.com/google/pprof v0.0.0-20241101162523-b92577c0c142 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/onsi/ginkgo/v2 v2.21.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ github.com/google/pprof v0.0.0-20241101162523-b92577c0c142 h1:sAGdeJj0bnMgUNVeUp
github.com/google/pprof v0.0.0-20241101162523-b92577c0c142/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8=
Expand Down
11 changes: 11 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ func main() {

ctx, canc := context.WithTimeout(ctx, *shutdownGraceFlag)
defer canc()

go func() {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGINT)
signal.Notify(ch, syscall.SIGTERM)
select {
case <-ch:
os.Exit(1)
case <-ctx.Done():
}
}()
p.Shutdown(ctx)
}

Expand Down
13 changes: 13 additions & 0 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ type Config struct {
// Each backend can be associated with one group. The group's limits
// are shared between all the backends associated with it.
BWLimits []*BWLimit `yaml:"bwLimits,omitempty"`
// WebSockets is a list of WebSocket endpoints and where they get
// forwarded.
// Incoming WebSocket requests are bridged to TCP connections. The
// content of BinaryMessages are streamed to the TCP server, and
// data received from the server is sent to the client also
// as BinaryMessages.
WebSockets []*WebSocketConfig `yaml:"webSockets,omitempty"`

acceptProxyHeaderFrom []*net.IPNet
}
Expand Down Expand Up @@ -203,6 +210,12 @@ type TLSCertificate struct {
CertFile string `yaml:"cert"`
}

// WebSocketConfig specifies a WebSocket endpoint.
type WebSocketConfig struct {
Endpoint string `yaml:"endpoint"`
Address string `yaml:"address,omitempty"`
}

// Backend encapsulates the data of one backend.
type Backend struct {
// ServerNames is the list of all the server names for this service,
Expand Down
11 changes: 11 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/c2FmZQ/storage/autocertcache"
"github.com/c2FmZQ/storage/crypto"
"github.com/c2FmZQ/tpm"
"github.com/gorilla/websocket"
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
Expand Down Expand Up @@ -117,6 +118,7 @@ type Proxy struct {
mk crypto.MasterKey
store *storage.Storage
tokenManager *tokenmanager.TokenManager
wsUpgrader *websocket.Upgrader

mu sync.RWMutex
connClosed *sync.Cond
Expand Down Expand Up @@ -791,6 +793,15 @@ func (p *Proxy) Reconfigure(cfg *Config) error {
}, pp.Endpoint)
}
}
if len(cfg.WebSockets) > 0 && p.wsUpgrader == nil {
p.wsUpgrader = newWebSocketUpgrader()
}
for _, ws := range cfg.WebSockets {
addLocalHandler(localHandler{
desc: "WebSocket Endpoint",
handler: logHandler(p.webSocketHandler(*ws)),
}, ws.Endpoint)
}
for _, be := range backends {
sort.Slice(be.localHandlers, func(i, j int) bool {
a := be.localHandlers[i].host
Expand Down
128 changes: 128 additions & 0 deletions proxy/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// MIT License
//
// Copyright (c) 2024 TTBT Enterprises LLC
// Copyright (c) 2024 Robin Thellend <[email protected]>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

package proxy

import (
"io"
"net"
"net/http"
"time"

"github.com/gorilla/websocket"
)

func newWebSocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
ReadBufferSize: 8192,
WriteBufferSize: 8192,
}
}

func (p *Proxy) webSocketHandler(cfg WebSocketConfig) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
in, err := p.wsUpgrader.Upgrade(w, req, nil)
if err != nil {
p.logErrorF("ERR %v", err)
return
}
defer in.Close()

dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
out, err := dialer.DialContext(req.Context(), "tcp", cfg.Address)
if err != nil {
p.logErrorF("ERR webSocketHandler: %v", err)
return
}

done := make(chan bool, 2)

lastActive := time.Now()
in.SetPongHandler(func(string) error {
lastActive = time.Now()
return nil
})
go func() {
ctx := req.Context()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if time.Since(lastActive) > 30*time.Second {
done <- true
return
}
if err := in.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second)); err != nil {
p.logErrorF("ERR WriteControl: %v", err)
}
}
}
}()

// in -> out loop
go func() {
defer func() {
done <- true
}()
for {
messageType, r, err := in.NextReader()
if err != nil {
return
}
if messageType != websocket.BinaryMessage {
continue
}
if _, err := io.Copy(out, r); err != nil {
return
}
}
}()

// out -> in loop
go func() {
defer func() {
done <- true
}()
buf := make([]byte, 1024)
for {
n, err := out.Read(buf)
if n > 0 {
if err := in.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil {
return
}
}
if err != nil {
return
}
}
}()

<-done
})
}

0 comments on commit 019832a

Please sign in to comment.