Skip to content

Commit

Permalink
feat: support client
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Oct 22, 2024
1 parent 0c152ad commit baa2cd2
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 19 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
.idea
# vscode
.vscode

# Go workspace file
go.work
go.work.sum
90 changes: 90 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package websocket

import (
"bytes"
"errors"
"fmt"
"time"

"github.com/cloudwego/hertz/pkg/protocol"
)

// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")

// ClientUpgrader is a helper for upgrading hertz http response to websocket conn.
// See ExampleClient for usage
type ClientUpgrader struct {
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int

// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool

// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}

// PrepareRequest prepares request for websocket
//
// It adds headers for websocket,
// and it must be called BEFORE sending http request via cli.DoXXX
func (p *ClientUpgrader) PrepareRequest(req *protocol.Request) {
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", generateChallengeKey())
if p.EnableCompression {
req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
}

// UpgradeResponse upgrades a response to websocket conn
//
// It returns Conn if success. ErrBadHandshake is returned if headers go wrong.
// This method must be called after PrepareRequest and (*.Client).DoXXX
func (p *ClientUpgrader) UpgradeResponse(req *protocol.Request, resp *protocol.Response) (*Conn, error) {
if resp.StatusCode() != 101 ||
!tokenContainsValue(resp.Header.Get("Upgrade"), "websocket") ||
!tokenContainsValue(resp.Header.Get("Connection"), "Upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKeyBytes(req.Header.Peek("Sec-Websocket-Key")) {
return nil, ErrBadHandshake
}

c, err := resp.Hijack()
if err != nil {
return nil, fmt.Errorf("Hijack response connection err: %w", err)
}

c.SetDeadline(time.Time{})
conn := newConn(c, false, p.ReadBufferSize, p.WriteBufferSize, p.WriteBufferPool, nil, nil)

// can not use p.EnableCompression, always follow ext returned from server
compress := false
extensions := parseDataHeader(resp.Header.Peek("Sec-WebSocket-Extensions"))
for _, ext := range extensions {
if bytes.HasPrefix(ext, strPermessageDeflate) {
compress = true
}
}
if compress {
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
}
return conn, nil
}
85 changes: 85 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package websocket

import (
"context"
"fmt"
"log"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol"
)

const (
testaddr = "localhost:10012"
testpath = "/echo"
)

func ExampleClient() {
runServer(testaddr)
time.Sleep(50 * time.Millisecond) // await server running

c, err := client.NewClient(client.WithDialer(standard.NewDialer()))
if err != nil {
panic(err)
}

req, resp := protocol.AcquireRequest(), protocol.AcquireResponse()
req.SetRequestURI("http://" + testaddr + testpath)
req.SetMethod("GET")

u := &ClientUpgrader{}
u.PrepareRequest(req)
err = c.Do(context.Background(), req, resp)
if err != nil {
panic(err)
}
conn, err := u.UpgradeResponse(req, resp)
if err != nil {
panic(err)
}

conn.WriteMessage(TextMessage, []byte("hello"))
m, b, err := conn.ReadMessage()
if err != nil {
panic(err)
}
fmt.Println(m, string(b))
// Output: 1 hello
}

func runServer(addr string) {
upgrader := HertzUpgrader{} // use default options
h := server.Default(server.WithHostPorts(addr))
// https://github.com/cloudwego/hertz/issues/121
h.NoHijackConnPool = true
h.GET(testpath, func(_ context.Context, c *app.RequestContext) {
err := upgrader.Upgrade(c, func(conn *Conn) {
for {
mt, message, err := conn.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("[server] recv: %v %s", mt, message)
err = conn.WriteMessage(mt, message)
if err != nil {
log.Println("write:", err)
break
}
}
})
if err != nil {
log.Print("upgrade:", err)
return
}
})
go func() {
if err := h.Run(); err != nil {
log.Fatal(err)
}
}()
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module github.com/hertz-contrib/websocket
go 1.16

require (
github.com/bytedance/sonic v1.11.9
github.com/cloudwego/hertz v0.9.1
github.com/bytedance/sonic v1.12.0
github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81
)
38 changes: 21 additions & 17 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I=
github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM=
github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM=
github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw=
github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.11.9 h1:LFHENlIY/SLzDWverzdOvgMztTxcfcF+cqNsz9pK5zg=
github.com/bytedance/sonic v1.11.9/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ=
github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs=
github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ=
github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4=
github.com/bytedance/mockey v1.2.12/go.mod h1:3ZA4MQasmqC87Tw0w7Ygdy7eHIc2xgpZ8Pona5rsYIk=
github.com/bytedance/sonic v1.12.0 h1:YGPgxF9xzaCNvd/ZKdQ28yRovhfMFZQjuk6fKBzZ3ls=
github.com/bytedance/sonic v1.12.0/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/hertz v0.9.1 h1:+jK9A6MDNTUVy6q/zSOlhbnp1fFMiOaPIsq0jlOfjZE=
github.com/cloudwego/hertz v0.9.1/go.mod h1:cs8dH6unM4oaJ5k9m6pqbgLBPqakGWMG0+cthsxitsg=
github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81 h1:lrZ2nuRsR4M9KG1N+ihkict9Q2gzNwFxmId6NksKCAY=
github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81/go.mod h1:gGVUfJU/BOkJv/ZTzrw7FS7uy7171JeYIZvAyV3wS3o=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U=
github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.6.2 h1:+KdILv5ATJU+222wNNXpHapYaBeRvvL8qhJyhcxRxrQ=
github.com/cloudwego/netpoll v0.6.2/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -74,12 +72,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VA
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
9 changes: 9 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"math/rand"
"unicode/utf8"
"unsafe"
)

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func generateChallengeKey() string {
b := make([]byte, 16)
binary.BigEndian.PutUint64(b, rand.Uint64())
binary.BigEndian.PutUint64(b[8:], rand.Uint64())
return base64.StdEncoding.EncodeToString(b)
}

// Token octets per RFC 2616.
var isTokenOctet = [256]bool{
'!': true,
Expand Down

0 comments on commit baa2cd2

Please sign in to comment.