diff --git a/context.go b/context.go index 229527c..0a64840 100644 --- a/context.go +++ b/context.go @@ -4,6 +4,7 @@ package srtp import ( + "bytes" "fmt" "github.com/pion/transport/v3/replaydetector" @@ -56,6 +57,10 @@ type Context struct { newSRTCPReplayDetector func() replaydetector.ReplayDetector newSRTPReplayDetector func() replaydetector.ReplayDetector + + profile ProtectionProfile + sendMKI []byte + mkis map[string]srtpCipher } // CreateContext creates a new SRTP Context. @@ -66,52 +71,99 @@ type Context struct { // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { - keyLen, err := profile.KeyLen() + c = &Context{ + srtpSSRCStates: map[uint32]*srtpSSRCState{}, + srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + profile: profile, + mkis: map[string]srtpCipher{}, + } + + for _, o := range append( + []ContextOption{ // Default options + SRTPNoReplayProtection(), + SRTCPNoReplayProtection(), + }, + opts..., // User specified options + ) { + if errOpt := o(c); errOpt != nil { + return nil, errOpt + } + } + + err = c.AddCipherForMKI(c.sendMKI, masterKey, masterSalt) if err != nil { return nil, err } + c.cipher = c.mkis[string(c.sendMKI)] - saltLen, err := profile.SaltLen() + return c, nil +} + +// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option to enable MKI support. +func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { + if len(mki) != len(c.sendMKI) { + return errInvalidMKILength + } + if _, ok := c.mkis[string(mki)]; ok { + return errMKIAlreadyInUse + } + + keyLen, err := c.profile.KeyLen() if err != nil { - return nil, err + return err } - if masterKeyLen := len(masterKey); masterKeyLen != keyLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) - } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) + saltLen, err := c.profile.SaltLen() + if err != nil { + return err } - c = &Context{ - srtpSSRCStates: map[uint32]*srtpSSRCState{}, - srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + if masterKeyLen := len(masterKey); masterKeyLen != keyLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) + } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) } - switch profile { + var cipher srtpCipher + switch c.profile { case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: - c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki) case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: - c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki) default: - return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) + return fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile) } if err != nil { - return nil, err + return err } - for _, o := range append( - []ContextOption{ // Default options - SRTPNoReplayProtection(), - SRTCPNoReplayProtection(), - }, - opts..., // User specified options - ) { - if errOpt := o(c); errOpt != nil { - return nil, errOpt - } + c.mkis[string(mki)] = cipher + return nil +} + +// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) RemoveMKI(mki []byte) error { + if _, ok := c.mkis[string(mki)]; !ok { + return ErrMKINotFound + } + if bytes.Equal(mki, c.sendMKI) { + return errMKIAlreadyInUse } + delete(c.mkis, string(mki)) + return nil +} - return c, nil +// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with encrypting packets. +func (c *Context) SetSendMKI(mki []byte) error { + cipher, ok := c.mkis[string(mki)] + if !ok { + return ErrMKINotFound + } + c.sendMKI = mki + c.cipher = cipher + return nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 diff --git a/context_test.go b/context_test.go index 60b078c..3df0f36 100644 --- a/context_test.go +++ b/context_test.go @@ -5,6 +5,8 @@ package srtp import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestContextROC(t *testing.T) { @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) { t.Errorf("Index is set to 100, but returned %d", index) } } + +func TestContextWithoutMKI(t *testing.T) { + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.SetSendMKI(nil) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 0)) + assert.NoError(t, err) + + err = c.RemoveMKI(nil) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 0)) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 2)) + assert.Error(t, err) +} + +func TestAddMKIToContextWithMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) +} + +func TestContextSetSendMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.SetSendMKI(mki1) + assert.NoError(t, err) + + err = c.SetSendMKI(mki2) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 4)) + assert.Error(t, err) +} + +func TestContextRemoveMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + mki3 := []byte{3, 4, 5, 6} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.RemoveMKI(make([]byte, 4)) + assert.Error(t, err) + + err = c.RemoveMKI(mki1) + assert.Error(t, err) + + err = c.SetSendMKI(mki3) + assert.NoError(t, err) + + err = c.RemoveMKI(mki1) + assert.NoError(t, err) + + err = c.RemoveMKI(mki2) + assert.NoError(t, err) + + err = c.RemoveMKI(mki3) + assert.Error(t, err) +} diff --git a/errors.go b/errors.go index f918b9c..6be0d1d 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,11 @@ import ( ) var ( + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag + ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") + // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet + ErrMKINotFound = errors.New("MKI not found") + errDuplicated = errors.New("duplicated packet") errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") @@ -17,13 +22,14 @@ var ( errExporterWrongLabel = errors.New("exporter called with wrong label") errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") - errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") errTooShortRTP = errors.New("packet is too short to be RTP packet") errTooShortRTCP = errors.New("packet is too short to be RTCP packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") errExceededMaxPackets = errors.New("exceeded the maximum number of packets") + errMKIAlreadyInUse = errors.New("MKI already in use") + errInvalidMKILength = errors.New("invalid MKI length") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") diff --git a/option.go b/option.go index 67bdf2e..f7f011b 100644 --- a/option.go +++ b/option.go @@ -71,3 +71,14 @@ type nopReplayDetector struct{} func (s *nopReplayDetector) Check(uint64) (func() bool, bool) { return func() bool { return true }, true } + +// MasterKeyIndicator sets MKI for RTP and RTCP. +func MasterKeyIndicator(mki []byte) ContextOption { + return func(c *Context) error { + if len(mki) > 0 { + c.sendMKI = make([]byte, len(mki)) + copy(c.sendMKI, mki) + } + return nil + } +} diff --git a/srtcp.go b/srtcp.go index 86a963f..8f120b5 100644 --- a/srtcp.go +++ b/srtcp.go @@ -23,9 +23,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + mkiLen := len(c.sendMKI) + tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize) - if tailOffset < aeadAuthTagLen { + if tailOffset < aeadAuthTagLen+8 { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { return out, nil @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } - out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(encrypted, false) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + out, err = cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } diff --git a/srtcp_test.go b/srtcp_test.go index d965e62..ca1a809 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -618,3 +618,163 @@ func TestEncryptInvalidRTCP(t *testing.T) { _, err = decryptContext.EncryptRTCP(nil, packet, nil) assert.Error(err) } + +func TestRTCPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + _, err = decryptContext.DecryptRTCP(nil, encrypted, nil) + if err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SSRC: %d", pkt.ssrc) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } + }) + } +} + +func TestRTCPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + encryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext1.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + encrypted2, err := encryptContext2.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + decrypted1, err := decryptContext.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + } + }) + } +} + +func TestRTCPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + decryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + decryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } + }) + } +} diff --git a/srtp.go b/srtp.go index e02ceff..ca7c3fc 100644 --- a/srtp.go +++ b/srtp.go @@ -14,7 +14,7 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL return nil, err } - if len(ciphertext) < headerLen+authTagLen { + if len(ciphertext) < headerLen+len(c.sendMKI)+authTagLen { return nil, errTooShortRTP } @@ -30,9 +30,19 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } } - dst = growBufferSize(dst, len(ciphertext)-authTagLen) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(ciphertext, true) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + dst = growBufferSize(dst, len(ciphertext)-authTagLen-len(c.sendMKI)) - dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) + dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } diff --git a/srtp_cipher.go b/srtp_cipher.go index 1af88df..da745e7 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -16,6 +16,7 @@ type srtpCipher interface { // See the note below. AEADAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 + getMKI([]byte, bool) []byte encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 744fbbb..e59c141 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -21,9 +21,11 @@ type srtpCipherAeadAesGcm struct { srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte + + mki []byte } -func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { +func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAeadAesGcm, error) { s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) @@ -62,6 +64,12 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt [] return nil, err } + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -71,7 +79,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen+len(s.mki)) n, err := header.MarshalTo(dst) if err != nil { @@ -80,6 +88,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa iv := s.rtpInitializationVector(header, roc) s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) + + // Add MKI after the encrypted payload + if len(s.mki) > 0 { + copy(dst[len(dst)-len(s.mki):], s.mki) + } + return dst, nil } @@ -89,17 +103,18 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He if err != nil { return nil, err } - nDst := len(ciphertext) - authTagLen - if nDst < 0 { + nDst := len(ciphertext) - authTagLen - len(s.mki) + if nDst < headerLen { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtpInitializationVector(header, roc) + nEnd := len(ciphertext) - len(s.mki) if _, err := s.srtpCipher.Open( - dst[headerLen:headerLen], iv[:], ciphertext[headerLen:], ciphertext[:headerLen], + dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen], ); err != nil { return nil, err } @@ -115,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin } aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. - dst = growBufferSize(dst, aadPos+srtcpIndexSize) + dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki)) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) @@ -124,11 +139,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin copy(dst[:8], decrypted[:8]) copy(dst[aadPos:aadPos+4], aad[8:12]) + copy(dst[aadPos+4:], s.mki) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { - aadPos := len(encrypted) - srtcpIndexSize + aadPos := len(encrypted) - srtcpIndexSize - len(s.mki) // Grow the given buffer to fit the output. authTagLen, err := s.AEADAuthTagLen() if err != nil { @@ -137,7 +153,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) @@ -205,5 +221,15 @@ func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte } func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { - return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) + return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-4:]) &^ (rtcpEncryptionFlag << 24) +} + +func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + tailOffset := len(in) - mkiLen + return in[tailOffset:] } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index 11369af..0ffe51a 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -25,9 +25,11 @@ type srtpCipherAesCmHmacSha1 struct { srtcpSessionSalt []byte srtcpSessionAuth hash.Hash srtcpBlock cipher.Block + + mki []byte } -func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { +func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -66,6 +68,13 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -75,7 +84,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) @@ -96,6 +105,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay return nil, err } + // Append the MKI (if used) + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += len(s.mki) + } + // Write the auth tag to the dest. copy(dst[n:], authTag) @@ -108,8 +123,10 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp if err != nil { return nil, err } + + // Split the auth tag and the cipher text into two parts. actualTag := ciphertext[len(ciphertext)-authTagLen:] - ciphertext = ciphertext[:len(ciphertext)-authTagLen] + ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) @@ -120,7 +137,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } // Write the plaintext header to the destination buffer. @@ -148,10 +165,18 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex) dst[len(dst)-4] |= 0x80 + // Generate the authentication tag authTag, err := s.generateSrtcpAuthTag(dst) if err != nil { return nil, err } + + // Include the MKI if provided + if len(s.mki) > 0 { + dst = append(dst, s.mki...) + } + + // Append the auth tag at the end of the buffer return append(dst, authTag...), nil } @@ -160,20 +185,20 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + tailOffset := len(encrypted) - (authTagLen + len(s.mki) + srtcpIndexSize) if tailOffset < 8 { return nil, errTooShortRTCP } out = out[0:tailOffset] - expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-len(s.mki)-authTagLen]) if err != nil { return nil, err } actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) @@ -247,7 +272,23 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.AuthTagRTCPLen() - tailOffset := len(in) - (authTagLen + srtcpIndexSize) + tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } + +func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + var authTagLen int + if rtp { + authTagLen, _ = s.AuthTagRTPLen() + } else { + authTagLen, _ = s.AuthTagRTCPLen() + } + tailOffset := len(in) - (authTagLen + mkiLen) + return in[tailOffset : tailOffset+mkiLen] +} diff --git a/srtp_cipher_test.go b/srtp_cipher_test.go index 2066c94..bb2d4b9 100644 --- a/srtp_cipher_test.go +++ b/srtp_cipher_test.go @@ -13,11 +13,15 @@ type testCipher struct { profile ProtectionProfile // Protection profile masterKey []byte // Master key masterSalt []byte // Master salt + mki []byte // Master key identifier - decryptedRTPPacket []byte - encryptedRTPPacket []byte - decryptedRTCPPacket []byte - encryptedRTCPPacket []byte + decryptedRTPPacket []byte + encryptedRTPPacket []byte + encryptedRTPPacketWithMKI []byte + + decryptedRTCPPacket []byte + encryptedRTCPPacket []byte + encryptedRTCPPacketWithMKI []byte } // create array of testCiphers for each supported profile @@ -38,6 +42,21 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x03, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, { profile: ProtectionProfileAes128CmHmacSha1_80, @@ -55,8 +74,24 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x03, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, 0xca, 0x44, 0x32, 0xa5, + 0x6e, 0x3d, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes128Gcm, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -74,11 +109,27 @@ func createTestCiphers() []testCipher { 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, 0xde, + 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, + 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, 0x68, 0x26, + 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, + 0x27, 0xe8, 0xa3, 0x92, 0x01, 0x02, 0x03, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, + 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, + 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, + 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes256Gcm, encryptedRTPPacket: []byte{ - 0x80, 0xf, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, @@ -93,6 +144,22 @@ func createTestCiphers() []testCipher { 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, + 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, + 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, + 0xb7, 0x9a, 0x2f, 0x95, 0x60, 0xd4, 0x69, 0x75, + 0x98, 0x50, 0x77, 0x25, 0x01, 0x02, 0x03, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x98, 0x22, 0xba, 0x22, 0x96, 0x1c, 0x31, 0x48, + 0xe7, 0xb7, 0xec, 0x4f, 0x09, 0xf4, 0x26, 0xdc, + 0xf6, 0xb5, 0x9a, 0x75, 0xad, 0xec, 0x74, 0xfd, + 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + }, }, } @@ -106,6 +173,7 @@ func createTestCiphers() []testCipher { 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0x0ac, 0xad, } + mki := []byte{0x01, 0x02, 0x03, 0x04} decryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, @@ -129,6 +197,7 @@ func createTestCiphers() []testCipher { } tests[k].masterKey = masterKey[:keyLen] tests[k].masterSalt = masterSalt[:saltLen] + tests[k].mki = mki tests[k].decryptedRTPPacket = decryptedRTPPacket tests[k].decryptedRTCPPacket = decryptedRTCPPacket } @@ -182,6 +251,50 @@ func TestSrtpCipher(t *testing.T) { assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) }) }) + + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTCPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + }) + }) }) } } diff --git a/srtp_test.go b/srtp_test.go index 8b81242..794e474 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -896,3 +896,161 @@ func TestDecryptInvalidSRTP(t *testing.T) { _, err = decryptContext.DecryptRTP(nil, packet, nil) assert.Error(err) } + +func TestRTPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := decryptContext.DecryptRTP(nil, out, nil); err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SeqNum: %d", testCase.sequenceNumber) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } +} + +func TestRTPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + encryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext1.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + encrypted2, err := encryptContext2.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + decrypted1, err := decryptContext.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + } +} + +func TestRTPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + decryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + decryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } +}