Skip to content

Commit

Permalink
Merge pull request v2ray#1950 from keepalivesrc/patch-2
Browse files Browse the repository at this point in the history
Websocket Read Limit Fix
  • Loading branch information
kslr authored Oct 18, 2019
2 parents 607425b + 01c7bba commit 31a647b
Showing 1 changed file with 48 additions and 10 deletions.
58 changes: 48 additions & 10 deletions external/github.com/gorilla/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,12 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser

// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
// bytes remaining in current frame.
// set setReadRemaining to safely update this value and prevent overflow
readRemaining int64
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
Expand Down Expand Up @@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c
}

// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func (c *Conn) setReadRemaining(n int64) error {
if n < 0 {
return ErrReadLimit
}

c.readRemaining = n
return nil
}

// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
return c.subprotocol
Expand Down Expand Up @@ -770,7 +783,7 @@ func (c *Conn) advanceFrame() (int, error) {
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)
c.setReadRemaining(int64(p[1] & 0x7f))

c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
Expand Down Expand Up @@ -804,21 +817,37 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}

// 3. Read and parse frame length.
// 3. Read and parse frame length as per
// https://tools.ietf.org/html/rfc6455#section-5.2
//
// The length of the "Payload data", in bytes: if 0-125, that is the payload
// length.
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
// integer are the payload length.
// - If 127, the following 8 bytes interpreted as
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
// payload length. Multibyte length quantities are expressed in network byte
// order.

switch c.readRemaining {
case 126:
p, err := c.read(2)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(p))

if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
return noFrame, err
}
case 127:
p, err := c.read(8)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(p))

if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
return noFrame, err
}
}

// 4. Handle frame masking.
Expand All @@ -841,6 +870,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

c.readLength += c.readRemaining
// Don't allow readLength to overflow in the presence of a large readRemaining
// counter.
if c.readLength < 0 {
return noFrame, ErrReadLimit
}

if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
Expand All @@ -854,7 +889,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0
c.setReadRemaining(0)
if err != nil {
return noFrame, err
}
Expand Down Expand Up @@ -927,6 +962,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readErr = hideTempErr(err)
break
}

if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
Expand Down Expand Up @@ -967,7 +1003,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
c.readRemaining -= int64(n)
rem := c.readRemaining
rem -= int64(n)
c.setReadRemaining(rem)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
Expand Down

0 comments on commit 31a647b

Please sign in to comment.