From 1acbef07eaf977cdf8c7392ab6e6107af8c27992 Mon Sep 17 00:00:00 2001 From: "daniel@poradnik-webmastera.com" Date: Sat, 6 Jul 2024 12:42:08 +0200 Subject: [PATCH] Fix packet length validation Missing packet length validations could cause crash. --- errors.go | 3 ++- srtcp.go | 4 ++++ srtcp_test.go | 24 ++++++++++++++++++++++++ srtp.go | 13 +++++++++---- srtp_cipher_aes_cm_hmac_sha1.go | 3 +++ srtp_test.go | 12 ++++++++++++ 6 files changed, 54 insertions(+), 5 deletions(-) diff --git a/errors.go b/errors.go index 5b1751d..f918b9c 100644 --- a/errors.go +++ b/errors.go @@ -18,7 +18,8 @@ var ( errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") - errTooShortRTCP = errors.New("packet is too short to be rtcp packet") + errTooShortRTP = errors.New("packet is too short to be RTP packet") + errTooShortRTCP = errors.New("packet is too short to be RTCP packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") diff --git a/srtcp.go b/srtcp.go index 7fd0746..0aae9ab 100644 --- a/srtcp.go +++ b/srtcp.go @@ -63,6 +63,10 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt } func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { + if len(decrypted) < 8 { + return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(decrypted)) + } + ssrc := binary.BigEndian.Uint32(decrypted[4:]) s := c.getSRTCPSSRCState(ssrc) diff --git a/srtcp_test.go b/srtcp_test.go index 8cf8e3a..d879d0e 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -594,3 +594,27 @@ func TestRTCPReplayDetectorFactory(t *testing.T) { } assert.Equal(1, cntFactory) } + +func TestDecryptInvalidSRTCP(t *testing.T) { + assert := assert.New(t) + key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(err) + + packet := []byte{0x8f, 0x48, 0xff, 0xff, 0xec, 0x77, 0xb0, 0x43, 0xf9, 0x04, 0x51, 0xff, 0xfb, 0xdf} + _, err = decryptContext.DecryptRTCP(nil, packet, nil) + assert.Error(err) +} + +func TestEncryptInvalidRTCP(t *testing.T) { + assert := assert.New(t) + key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(err) + + packet := []byte{0xbb, 0xbb, 0x0a, 0x2f} + _, err = decryptContext.EncryptRTCP(nil, packet, nil) + assert.Error(err) +} diff --git a/srtp.go b/srtp.go index 42c71be..3a7ed11 100644 --- a/srtp.go +++ b/srtp.go @@ -9,6 +9,15 @@ import ( ) func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { + authTagLen, err := c.cipher.rtpAuthTagLen() + if err != nil { + return nil, err + } + + if len(ciphertext) < headerLen+authTagLen { + return nil, errTooShortRTP + } + s := c.getSRTPSSRCState(header.SSRC) roc, diff, _ := s.nextRolloverCount(header.SequenceNumber) @@ -21,10 +30,6 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } } - authTagLen, err := c.cipher.rtpAuthTagLen() - if err != nil { - return nil, err - } dst = growBufferSize(dst, len(ciphertext)-authTagLen) dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index d56e6af..20e0376 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -161,6 +161,9 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc return nil, err } tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + if tailOffset < 8 { + return nil, errTooShortRTCP + } out = out[0:tailOffset] expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) diff --git a/srtp_test.go b/srtp_test.go index c27f598..6b23b7b 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -884,3 +884,15 @@ func TestRTPBurstLossWithSetROC(t *testing.T) { }) } } + +func TestDecryptInvalidSRTP(t *testing.T) { + assert := assert.New(t) + key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01} + decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(err) + + packet := []byte{0x41, 0x02, 0x07, 0xf9, 0xf9, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb5, 0x73, 0x19, 0xf6, 0x91, 0xbb, 0x3e, 0xa5, 0x21, 0x07} + _, err = decryptContext.DecryptRTP(nil, packet, nil) + assert.Error(err) +}