-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
214 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,7 @@ | |
.idea | ||
# vscode | ||
.vscode | ||
|
||
# Go workspace file | ||
go.work | ||
go.work.sum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package websocket | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/cloudwego/hertz/pkg/protocol" | ||
) | ||
|
||
// ErrBadHandshake is returned when the server response to opening handshake is | ||
// invalid. | ||
var ErrBadHandshake = errors.New("websocket: bad handshake") | ||
|
||
type ClientUpgrader struct { | ||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer | ||
// size is zero, then buffers allocated by the HTTP server are used. The | ||
// I/O buffer sizes do not limit the size of the messages that can be sent | ||
// or received. | ||
ReadBufferSize, WriteBufferSize int | ||
|
||
// WriteBufferPool is a pool of buffers for write operations. If the value | ||
// is not set, then write buffers are allocated to the connection for the | ||
// lifetime of the connection. | ||
// | ||
// A pool is most useful when the application has a modest volume of writes | ||
// across a large number of connections. | ||
// | ||
// Applications should use a single pool for each unique value of | ||
// WriteBufferSize. | ||
WriteBufferPool BufferPool | ||
|
||
// EnableCompression specify if the server should attempt to negotiate per | ||
// message compression (RFC 7692). Setting this value to true does not | ||
// guarantee that compression will be supported. Currently only "no context | ||
// takeover" modes are supported. | ||
EnableCompression bool | ||
} | ||
|
||
func (p *ClientUpgrader) PrepareRequest(req *protocol.Request) { | ||
req.Header.Set("Upgrade", "websocket") | ||
req.Header.Set("Connection", "Upgrade") | ||
req.Header.Set("Sec-WebSocket-Version", "13") | ||
req.Header.Set("Sec-WebSocket-Key", generateChallengeKey()) | ||
if p.EnableCompression { | ||
req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") | ||
} | ||
} | ||
|
||
func (p *ClientUpgrader) UpgradeResponse(req *protocol.Request, resp *protocol.Response) (*Conn, error) { | ||
if resp.StatusCode() != 101 || | ||
!tokenContainsValue(resp.Header.Get("Upgrade"), "websocket") || | ||
!tokenContainsValue(resp.Header.Get("Connection"), "Upgrade") || | ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKeyBytes(req.Header.Peek("Sec-Websocket-Key")) { | ||
return nil, ErrBadHandshake | ||
} | ||
|
||
c, err := resp.Hijack() | ||
if err != nil { | ||
return nil, fmt.Errorf("Hijack response connection err: %w", err) | ||
} | ||
|
||
c.SetDeadline(time.Time{}) | ||
conn := newConn(c, false, p.ReadBufferSize, p.WriteBufferSize, p.WriteBufferPool, nil, nil) | ||
|
||
// can not use p.EnableCompression, always follow ext returned from server | ||
compress := false | ||
extensions := parseDataHeader(resp.Header.Peek("Sec-WebSocket-Extensions")) | ||
for _, ext := range extensions { | ||
if bytes.HasPrefix(ext, strPermessageDeflate) { | ||
compress = true | ||
} | ||
} | ||
if compress { | ||
conn.newCompressionWriter = compressNoContextTakeover | ||
conn.newDecompressionReader = decompressNoContextTakeover | ||
} | ||
return conn, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package websocket | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"log" | ||
"net" | ||
"runtime" | ||
"time" | ||
|
||
"github.com/cloudwego/hertz/pkg/app" | ||
"github.com/cloudwego/hertz/pkg/app/client" | ||
"github.com/cloudwego/hertz/pkg/app/server" | ||
"github.com/cloudwego/hertz/pkg/network/standard" | ||
"github.com/cloudwego/hertz/pkg/protocol" | ||
) | ||
|
||
const ( | ||
testaddr = "localhost:10012" | ||
testpath = "/echo" | ||
) | ||
|
||
func runServer(addr string) { | ||
upgrader := HertzUpgrader{} // use default options | ||
h := server.Default(server.WithHostPorts(addr)) | ||
// https://github.com/cloudwego/hertz/issues/121 | ||
h.NoHijackConnPool = true | ||
h.GET(testpath, func(_ context.Context, c *app.RequestContext) { | ||
err := upgrader.Upgrade(c, func(conn *Conn) { | ||
for { | ||
mt, message, err := conn.ReadMessage() | ||
if err != nil { | ||
log.Println("read:", err) | ||
break | ||
} | ||
log.Printf("[server] recv: %v %s", mt, message) | ||
err = conn.WriteMessage(mt, message) | ||
if err != nil { | ||
log.Println("write:", err) | ||
break | ||
} | ||
} | ||
}) | ||
if err != nil { | ||
log.Print("upgrade:", err) | ||
return | ||
} | ||
}) | ||
go h.Run() | ||
} | ||
|
||
func waitListener(addr string) { | ||
time.Sleep(5 * time.Millisecond) // likely it's up | ||
_, file, no, _ := runtime.Caller(1) | ||
for i := 0; i < 50; i++ { // 5s | ||
if ln, err := net.Dial("tcp", addr); err == nil { | ||
ln.Close() | ||
log.Printf("[server] %s is up @ %s:%d", addr, file, no) | ||
return | ||
} | ||
log.Printf("waiting server %s @ %s:%d", addr, file, no) | ||
time.Sleep(100 * time.Millisecond) | ||
} | ||
panic("server " + addr + " not ready") | ||
} | ||
|
||
func ExampleClient() { | ||
runServer(testaddr) | ||
waitListener(testaddr) | ||
|
||
c, err := client.NewClient(client.WithDialer(standard.NewDialer())) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() | ||
req.SetRequestURI("http://" + testaddr + testpath) | ||
req.SetMethod("GET") | ||
|
||
u := &ClientUpgrader{} | ||
u.PrepareRequest(req) | ||
err = c.Do(context.Background(), req, resp) | ||
if err != nil { | ||
panic(err) | ||
} | ||
conn, err := u.UpgradeResponse(req, resp) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
conn.WriteMessage(TextMessage, []byte("hello")) | ||
m, b, err := conn.ReadMessage() | ||
if err != nil { | ||
panic(err) | ||
} | ||
fmt.Println(m, string(b)) | ||
// Output: 1 hello | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters