Skip to content

Commit

Permalink
Fix packet length validation
Browse files Browse the repository at this point in the history
Missing packet length validations could cause crash.
  • Loading branch information
[email protected] authored and sirzooro committed Jul 6, 2024
1 parent 19b0fa0 commit 1acbef0
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 5 deletions.
3 changes: 2 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ var (
errNoConfig = errors.New("no config provided")
errNoConn = errors.New("no conn provided")
errFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
errTooShortRTCP = errors.New("packet is too short to be rtcp packet")
errTooShortRTP = errors.New("packet is too short to be RTP packet")
errTooShortRTCP = errors.New("packet is too short to be RTCP packet")
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
Expand Down
4 changes: 4 additions & 0 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt
}

func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
if len(decrypted) < 8 {
return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(decrypted))
}

ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)

Expand Down
24 changes: 24 additions & 0 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,27 @@ func TestRTCPReplayDetectorFactory(t *testing.T) {
}
assert.Equal(1, cntFactory)
}

func TestDecryptInvalidSRTCP(t *testing.T) {
assert := assert.New(t)
key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80)
assert.NoError(err)

packet := []byte{0x8f, 0x48, 0xff, 0xff, 0xec, 0x77, 0xb0, 0x43, 0xf9, 0x04, 0x51, 0xff, 0xfb, 0xdf}
_, err = decryptContext.DecryptRTCP(nil, packet, nil)
assert.Error(err)
}

func TestEncryptInvalidRTCP(t *testing.T) {
assert := assert.New(t)
key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80)
assert.NoError(err)

packet := []byte{0xbb, 0xbb, 0x0a, 0x2f}
_, err = decryptContext.EncryptRTCP(nil, packet, nil)
assert.Error(err)
}
13 changes: 9 additions & 4 deletions srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import (
)

func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) {
authTagLen, err := c.cipher.rtpAuthTagLen()
if err != nil {
return nil, err
}

if len(ciphertext) < headerLen+authTagLen {
return nil, errTooShortRTP
}

s := c.getSRTPSSRCState(header.SSRC)

roc, diff, _ := s.nextRolloverCount(header.SequenceNumber)
Expand All @@ -21,10 +30,6 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
}
}

authTagLen, err := c.cipher.rtpAuthTagLen()
if err != nil {
return nil, err
}
dst = growBufferSize(dst, len(ciphertext)-authTagLen)

dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc)
Expand Down
3 changes: 3 additions & 0 deletions srtp_cipher_aes_cm_hmac_sha1.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
if tailOffset < 8 {
return nil, errTooShortRTCP
}
out = out[0:tailOffset]

expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen])
Expand Down
12 changes: 12 additions & 0 deletions srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,3 +884,15 @@ func TestRTPBurstLossWithSetROC(t *testing.T) {
})
}
}

func TestDecryptInvalidSRTP(t *testing.T) {
assert := assert.New(t)
key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80)
assert.NoError(err)

packet := []byte{0x41, 0x02, 0x07, 0xf9, 0xf9, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb5, 0x73, 0x19, 0xf6, 0x91, 0xbb, 0x3e, 0xa5, 0x21, 0x07}
_, err = decryptContext.DecryptRTP(nil, packet, nil)
assert.Error(err)
}

0 comments on commit 1acbef0

Please sign in to comment.