diff --git a/websocket/client.go b/websocket/client.go index 69a4ac7eef..fd519f35df 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -42,7 +42,20 @@ func NewConfig(server, origin string) (config *Config, err error) { func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) { br := bufio.NewReader(rwc) bw := bufio.NewWriter(rwc) - err = hybiClientHandshake(config, br, bw) + _, err = hybiClientHandshake(config, br, bw) + if err != nil { + return + } + buf := bufio.NewReadWriter(br, bw) + ws = newHybiClientConn(config, buf, rwc) + return +} + +// NewClient2 creates a new WebSocket client connection over rwc. +func NewClient2(config *Config, rwc io.ReadWriteCloser) (ws *Conn, resp *http.Response, err error) { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + resp, err = hybiClientHandshake(config, br, bw) if err != nil { return } diff --git a/websocket/hybi.go b/websocket/hybi.go index 48a069e190..3805200838 100644 --- a/websocket/hybi.go +++ b/websocket/hybi.go @@ -106,6 +106,8 @@ func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } func (frame *hybiFrameReader) Len() (n int) { return frame.length } +func (frame *hybiFrameReader) FrameDataLength() (n int64) { return frame.header.Length } + // A hybiFrameReaderFactory creates new frame reader based on its frame type. type hybiFrameReaderFactory struct { *bufio.Reader @@ -402,7 +404,7 @@ func getNonceAccept(nonce []byte) (expected []byte, err error) { } // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 -func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { +func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (resp *http.Response, err error) { bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") // According to RFC 6874, an HTTP client, proxy, or other @@ -419,7 +421,7 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") if config.Version != ProtocolVersionHybi13 { - return ErrBadProtocolVersion + return nil, ErrBadProtocolVersion } bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") @@ -429,34 +431,34 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er // TODO(ukai): send Sec-WebSocket-Extensions. err = config.Header.WriteSubset(bw, handshakeHeader) if err != nil { - return err + return nil, err } bw.WriteString("\r\n") if err = bw.Flush(); err != nil { - return err + return nil, err } - resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + resp, err = http.ReadResponse(br, &http.Request{Method: "GET"}) if err != nil { - return err + return } if resp.StatusCode != 101 { - return ErrBadStatus + return resp, ErrBadStatus } if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { - return ErrBadUpgrade + return resp, ErrBadUpgrade } expectedAccept, err := getNonceAccept(nonce) if err != nil { - return err + return resp, err } if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { - return ErrChallengeResponse + return resp, ErrChallengeResponse } if resp.Header.Get("Sec-WebSocket-Extensions") != "" { - return ErrUnsupportedExtensions + return resp, ErrUnsupportedExtensions } offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") if offeredProtocol != "" { @@ -468,12 +470,12 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er } } if !protocolMatched { - return ErrBadWebSocketProtocol + return resp, ErrBadWebSocketProtocol } config.Protocol = []string{offeredProtocol} } - return nil + return resp, nil } // newHybiClientConn creates a client WebSocket connection after handshake. diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go index 9504aa2d30..f3402cdabc 100644 --- a/websocket/hybi_test.go +++ b/websocket/hybi_test.go @@ -68,7 +68,7 @@ Sec-WebSocket-Protocol: chat config.handshakeData = map[string]string{ "key": "dGhlIHNhbXBsZSBub25jZQ==", } - if err := hybiClientHandshake(&config, br, bw); err != nil { + if _, err := hybiClientHandshake(&config, br, bw); err != nil { t.Fatal("handshake", err) } req, err := http.ReadRequest(bufio.NewReader(&b)) @@ -132,7 +132,7 @@ Sec-WebSocket-Protocol: chat config.handshakeData = map[string]string{ "key": "dGhlIHNhbXBsZSBub25jZQ==", } - err = hybiClientHandshake(config, br, bw) + _, err = hybiClientHandshake(config, br, bw) if err != nil { t.Errorf("handshake failed: %v", err) } diff --git a/websocket/websocket.go b/websocket/websocket.go index 90a2257cd5..5505e7b2f7 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -134,6 +134,9 @@ type frameReader interface { // Len returns total length of the frame, including header and trailer. Len() int + + // FrameDataLength returns data length of the frame. + FrameDataLength() int64 } // frameReaderFactory is an interface to creates new frame reader. diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 2054ce85a6..975e414131 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -136,6 +136,53 @@ func newConfig(t *testing.T, path string) *Config { return config } +func doWriteRead(t *testing.T, conn *Conn) { + var buf bytes.Buffer + for i := 0; i < 5; i++ { + n, err := conn.Write([]byte(fmt.Sprintf("websocket_test:%d", i))) + if err != nil { + t.Fatal("write:", err) + return + } + for { + var b [2]byte + n, err = conn.Read(b[:]) + if n > 0 { + buf.Write(b[:n]) + } else if err != nil { + t.Fatal("err:", err) + return + } + if buf.Len() == int(conn.FrameDataLength()) { + t.Log("read:", string(buf.Bytes())) + buf.Reset() + break + } + } + } +} + +func TestNewClient2(t *testing.T) { + once.Do(startServer) + + // websocket.Dial() + client, err := net.Dial("tcp", serverAddr) + if err != nil { + t.Fatal("dialing", err) + } + conn, resp, err := NewClient2(newConfig(t, "/echo"), client) + if err != nil { + if resp != nil { + t.Fatal("newClient2:StatusCode:", resp.StatusCode, " err:", err) + } else { + t.Fatal("newClient2:StatusCode:", 0, " err:", err) + } + return + } + + doWriteRead(t, conn) +} + func TestEcho(t *testing.T) { once.Do(startServer)