From 4bbc17f9768e73f06e8178e2fe16908d138b9b40 Mon Sep 17 00:00:00 2001 From: "daniel@poradnik-webmastera.com" Date: Sat, 29 Jun 2024 21:30:57 +0200 Subject: [PATCH] Add support for Master Key Indicator This adds support for Master Key Indicator (MKI). It is used to select one of pre-configured SRTP/SRTCP encryption keys. To use it, Context has to be created with MasterKeyIndicator option, it specifies MKI for master key and salt passed to CreateContext. Additional master keys/salts with their MKIs can be added using AddCipherForMKI. To remove MKIs, use RemoveMKI. All MKIs must have the same length, and use the same length of master key and salt - they use the same crypto profile. SRTP/SRTCP packets by default are encrypted using first key/salt/MKI. To select other key/salt/MKI, use SetSendMKI. key/salt/MKI used for decryption are chosen automatically, using MKI sent in encrypted SRTP/SRTCP packet. --- context.go | 104 +++++++++++++++------ context_test.go | 122 ++++++++++++++++++++++++ errors.go | 8 +- option.go | 11 +++ srtcp.go | 17 +++- srtcp_test.go | 160 ++++++++++++++++++++++++++++++++ srtp.go | 16 +++- srtp_cipher.go | 1 + srtp_cipher_aead_aes_gcm.go | 46 +++++++-- srtp_cipher_aes_cm_hmac_sha1.go | 57 ++++++++++-- srtp_cipher_test.go | 127 +++++++++++++++++++++++-- srtp_test.go | 158 +++++++++++++++++++++++++++++++ 12 files changed, 769 insertions(+), 58 deletions(-) diff --git a/context.go b/context.go index 229527c..0a64840 100644 --- a/context.go +++ b/context.go @@ -4,6 +4,7 @@ package srtp import ( + "bytes" "fmt" "github.com/pion/transport/v3/replaydetector" @@ -56,6 +57,10 @@ type Context struct { newSRTCPReplayDetector func() replaydetector.ReplayDetector newSRTPReplayDetector func() replaydetector.ReplayDetector + + profile ProtectionProfile + sendMKI []byte + mkis map[string]srtpCipher } // CreateContext creates a new SRTP Context. @@ -66,52 +71,99 @@ type Context struct { // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { - keyLen, err := profile.KeyLen() + c = &Context{ + srtpSSRCStates: map[uint32]*srtpSSRCState{}, + srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + profile: profile, + mkis: map[string]srtpCipher{}, + } + + for _, o := range append( + []ContextOption{ // Default options + SRTPNoReplayProtection(), + SRTCPNoReplayProtection(), + }, + opts..., // User specified options + ) { + if errOpt := o(c); errOpt != nil { + return nil, errOpt + } + } + + err = c.AddCipherForMKI(c.sendMKI, masterKey, masterSalt) if err != nil { return nil, err } + c.cipher = c.mkis[string(c.sendMKI)] - saltLen, err := profile.SaltLen() + return c, nil +} + +// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option to enable MKI support. +func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { + if len(mki) != len(c.sendMKI) { + return errInvalidMKILength + } + if _, ok := c.mkis[string(mki)]; ok { + return errMKIAlreadyInUse + } + + keyLen, err := c.profile.KeyLen() if err != nil { - return nil, err + return err } - if masterKeyLen := len(masterKey); masterKeyLen != keyLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) - } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) + saltLen, err := c.profile.SaltLen() + if err != nil { + return err } - c = &Context{ - srtpSSRCStates: map[uint32]*srtpSSRCState{}, - srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + if masterKeyLen := len(masterKey); masterKeyLen != keyLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) + } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) } - switch profile { + var cipher srtpCipher + switch c.profile { case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: - c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki) case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: - c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki) default: - return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) + return fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile) } if err != nil { - return nil, err + return err } - for _, o := range append( - []ContextOption{ // Default options - SRTPNoReplayProtection(), - SRTCPNoReplayProtection(), - }, - opts..., // User specified options - ) { - if errOpt := o(c); errOpt != nil { - return nil, errOpt - } + c.mkis[string(mki)] = cipher + return nil +} + +// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) RemoveMKI(mki []byte) error { + if _, ok := c.mkis[string(mki)]; !ok { + return ErrMKINotFound + } + if bytes.Equal(mki, c.sendMKI) { + return errMKIAlreadyInUse } + delete(c.mkis, string(mki)) + return nil +} - return c, nil +// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with encrypting packets. +func (c *Context) SetSendMKI(mki []byte) error { + cipher, ok := c.mkis[string(mki)] + if !ok { + return ErrMKINotFound + } + c.sendMKI = mki + c.cipher = cipher + return nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 diff --git a/context_test.go b/context_test.go index 60b078c..3df0f36 100644 --- a/context_test.go +++ b/context_test.go @@ -5,6 +5,8 @@ package srtp import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestContextROC(t *testing.T) { @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) { t.Errorf("Index is set to 100, but returned %d", index) } } + +func TestContextWithoutMKI(t *testing.T) { + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.SetSendMKI(nil) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 0)) + assert.NoError(t, err) + + err = c.RemoveMKI(nil) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 0)) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 2)) + assert.Error(t, err) +} + +func TestAddMKIToContextWithMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) +} + +func TestContextSetSendMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.SetSendMKI(mki1) + assert.NoError(t, err) + + err = c.SetSendMKI(mki2) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 4)) + assert.Error(t, err) +} + +func TestContextRemoveMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + mki3 := []byte{3, 4, 5, 6} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.RemoveMKI(make([]byte, 4)) + assert.Error(t, err) + + err = c.RemoveMKI(mki1) + assert.Error(t, err) + + err = c.SetSendMKI(mki3) + assert.NoError(t, err) + + err = c.RemoveMKI(mki1) + assert.NoError(t, err) + + err = c.RemoveMKI(mki2) + assert.NoError(t, err) + + err = c.RemoveMKI(mki3) + assert.Error(t, err) +} diff --git a/errors.go b/errors.go index f918b9c..6be0d1d 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,11 @@ import ( ) var ( + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag + ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") + // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet + ErrMKINotFound = errors.New("MKI not found") + errDuplicated = errors.New("duplicated packet") errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") @@ -17,13 +22,14 @@ var ( errExporterWrongLabel = errors.New("exporter called with wrong label") errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") - errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") errTooShortRTP = errors.New("packet is too short to be RTP packet") errTooShortRTCP = errors.New("packet is too short to be RTCP packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") errExceededMaxPackets = errors.New("exceeded the maximum number of packets") + errMKIAlreadyInUse = errors.New("MKI already in use") + errInvalidMKILength = errors.New("invalid MKI length") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") diff --git a/option.go b/option.go index 67bdf2e..f7f011b 100644 --- a/option.go +++ b/option.go @@ -71,3 +71,14 @@ type nopReplayDetector struct{} func (s *nopReplayDetector) Check(uint64) (func() bool, bool) { return func() bool { return true }, true } + +// MasterKeyIndicator sets MKI for RTP and RTCP. +func MasterKeyIndicator(mki []byte) ContextOption { + return func(c *Context) error { + if len(mki) > 0 { + c.sendMKI = make([]byte, len(mki)) + copy(c.sendMKI, mki) + } + return nil + } +} diff --git a/srtcp.go b/srtcp.go index 86a963f..8f120b5 100644 --- a/srtcp.go +++ b/srtcp.go @@ -23,9 +23,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + mkiLen := len(c.sendMKI) + tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize) - if tailOffset < aeadAuthTagLen { + if tailOffset < aeadAuthTagLen+8 { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { return out, nil @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } - out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(encrypted, false) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + out, err = cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } diff --git a/srtcp_test.go b/srtcp_test.go index d965e62..ca1a809 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -618,3 +618,163 @@ func TestEncryptInvalidRTCP(t *testing.T) { _, err = decryptContext.EncryptRTCP(nil, packet, nil) assert.Error(err) } + +func TestRTCPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + _, err = decryptContext.DecryptRTCP(nil, encrypted, nil) + if err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SSRC: %d", pkt.ssrc) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } + }) + } +} + +func TestRTCPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + encryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext1.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + encrypted2, err := encryptContext2.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + decrypted1, err := decryptContext.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + } + }) + } +} + +func TestRTCPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + decryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + decryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } + }) + } +} diff --git a/srtp.go b/srtp.go index e02ceff..ca7c3fc 100644 --- a/srtp.go +++ b/srtp.go @@ -14,7 +14,7 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL return nil, err } - if len(ciphertext) < headerLen+authTagLen { + if len(ciphertext) < headerLen+len(c.sendMKI)+authTagLen { return nil, errTooShortRTP } @@ -30,9 +30,19 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } } - dst = growBufferSize(dst, len(ciphertext)-authTagLen) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(ciphertext, true) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + dst = growBufferSize(dst, len(ciphertext)-authTagLen-len(c.sendMKI)) - dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) + dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } diff --git a/srtp_cipher.go b/srtp_cipher.go index 1af88df..da745e7 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -16,6 +16,7 @@ type srtpCipher interface { // See the note below. AEADAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 + getMKI([]byte, bool) []byte encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 744fbbb..e59c141 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -21,9 +21,11 @@ type srtpCipherAeadAesGcm struct { srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte + + mki []byte } -func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { +func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAeadAesGcm, error) { s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) @@ -62,6 +64,12 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt [] return nil, err } + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -71,7 +79,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen+len(s.mki)) n, err := header.MarshalTo(dst) if err != nil { @@ -80,6 +88,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa iv := s.rtpInitializationVector(header, roc) s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) + + // Add MKI after the encrypted payload + if len(s.mki) > 0 { + copy(dst[len(dst)-len(s.mki):], s.mki) + } + return dst, nil } @@ -89,17 +103,18 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He if err != nil { return nil, err } - nDst := len(ciphertext) - authTagLen - if nDst < 0 { + nDst := len(ciphertext) - authTagLen - len(s.mki) + if nDst < headerLen { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtpInitializationVector(header, roc) + nEnd := len(ciphertext) - len(s.mki) if _, err := s.srtpCipher.Open( - dst[headerLen:headerLen], iv[:], ciphertext[headerLen:], ciphertext[:headerLen], + dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen], ); err != nil { return nil, err } @@ -115,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin } aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. - dst = growBufferSize(dst, aadPos+srtcpIndexSize) + dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki)) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) @@ -124,11 +139,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin copy(dst[:8], decrypted[:8]) copy(dst[aadPos:aadPos+4], aad[8:12]) + copy(dst[aadPos+4:], s.mki) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { - aadPos := len(encrypted) - srtcpIndexSize + aadPos := len(encrypted) - srtcpIndexSize - len(s.mki) // Grow the given buffer to fit the output. authTagLen, err := s.AEADAuthTagLen() if err != nil { @@ -137,7 +153,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) @@ -205,5 +221,15 @@ func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte } func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { - return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) + return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-4:]) &^ (rtcpEncryptionFlag << 24) +} + +func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + tailOffset := len(in) - mkiLen + return in[tailOffset:] } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index 11369af..0ffe51a 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -25,9 +25,11 @@ type srtpCipherAesCmHmacSha1 struct { srtcpSessionSalt []byte srtcpSessionAuth hash.Hash srtcpBlock cipher.Block + + mki []byte } -func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { +func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -66,6 +68,13 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -75,7 +84,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) @@ -96,6 +105,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay return nil, err } + // Append the MKI (if used) + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += len(s.mki) + } + // Write the auth tag to the dest. copy(dst[n:], authTag) @@ -108,8 +123,10 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp if err != nil { return nil, err } + + // Split the auth tag and the cipher text into two parts. actualTag := ciphertext[len(ciphertext)-authTagLen:] - ciphertext = ciphertext[:len(ciphertext)-authTagLen] + ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) @@ -120,7 +137,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } // Write the plaintext header to the destination buffer. @@ -148,10 +165,18 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex) dst[len(dst)-4] |= 0x80 + // Generate the authentication tag authTag, err := s.generateSrtcpAuthTag(dst) if err != nil { return nil, err } + + // Include the MKI if provided + if len(s.mki) > 0 { + dst = append(dst, s.mki...) + } + + // Append the auth tag at the end of the buffer return append(dst, authTag...), nil } @@ -160,20 +185,20 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + tailOffset := len(encrypted) - (authTagLen + len(s.mki) + srtcpIndexSize) if tailOffset < 8 { return nil, errTooShortRTCP } out = out[0:tailOffset] - expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-len(s.mki)-authTagLen]) if err != nil { return nil, err } actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) @@ -247,7 +272,23 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.AuthTagRTCPLen() - tailOffset := len(in) - (authTagLen + srtcpIndexSize) + tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } + +func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + var authTagLen int + if rtp { + authTagLen, _ = s.AuthTagRTPLen() + } else { + authTagLen, _ = s.AuthTagRTCPLen() + } + tailOffset := len(in) - (authTagLen + mkiLen) + return in[tailOffset : tailOffset+mkiLen] +} diff --git a/srtp_cipher_test.go b/srtp_cipher_test.go index 2066c94..bb2d4b9 100644 --- a/srtp_cipher_test.go +++ b/srtp_cipher_test.go @@ -13,11 +13,15 @@ type testCipher struct { profile ProtectionProfile // Protection profile masterKey []byte // Master key masterSalt []byte // Master salt + mki []byte // Master key identifier - decryptedRTPPacket []byte - encryptedRTPPacket []byte - decryptedRTCPPacket []byte - encryptedRTCPPacket []byte + decryptedRTPPacket []byte + encryptedRTPPacket []byte + encryptedRTPPacketWithMKI []byte + + decryptedRTCPPacket []byte + encryptedRTCPPacket []byte + encryptedRTCPPacketWithMKI []byte } // create array of testCiphers for each supported profile @@ -38,6 +42,21 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x03, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, { profile: ProtectionProfileAes128CmHmacSha1_80, @@ -55,8 +74,24 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x03, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, 0xca, 0x44, 0x32, 0xa5, + 0x6e, 0x3d, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes128Gcm, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -74,11 +109,27 @@ func createTestCiphers() []testCipher { 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, 0xde, + 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, + 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, 0x68, 0x26, + 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, + 0x27, 0xe8, 0xa3, 0x92, 0x01, 0x02, 0x03, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, + 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, + 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, + 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes256Gcm, encryptedRTPPacket: []byte{ - 0x80, 0xf, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, @@ -93,6 +144,22 @@ func createTestCiphers() []testCipher { 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, + 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, + 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, + 0xb7, 0x9a, 0x2f, 0x95, 0x60, 0xd4, 0x69, 0x75, + 0x98, 0x50, 0x77, 0x25, 0x01, 0x02, 0x03, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x98, 0x22, 0xba, 0x22, 0x96, 0x1c, 0x31, 0x48, + 0xe7, 0xb7, 0xec, 0x4f, 0x09, 0xf4, 0x26, 0xdc, + 0xf6, 0xb5, 0x9a, 0x75, 0xad, 0xec, 0x74, 0xfd, + 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, + }, }, } @@ -106,6 +173,7 @@ func createTestCiphers() []testCipher { 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0x0ac, 0xad, } + mki := []byte{0x01, 0x02, 0x03, 0x04} decryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, @@ -129,6 +197,7 @@ func createTestCiphers() []testCipher { } tests[k].masterKey = masterKey[:keyLen] tests[k].masterSalt = masterSalt[:saltLen] + tests[k].mki = mki tests[k].decryptedRTPPacket = decryptedRTPPacket tests[k].decryptedRTCPPacket = decryptedRTCPPacket } @@ -182,6 +251,50 @@ func TestSrtpCipher(t *testing.T) { assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) }) }) + + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTCPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + }) + }) }) } } diff --git a/srtp_test.go b/srtp_test.go index 8b81242..794e474 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -896,3 +896,161 @@ func TestDecryptInvalidSRTP(t *testing.T) { _, err = decryptContext.DecryptRTP(nil, packet, nil) assert.Error(err) } + +func TestRTPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := decryptContext.DecryptRTP(nil, out, nil); err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SeqNum: %d", testCase.sequenceNumber) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } +} + +func TestRTPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + encryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext1.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + encrypted2, err := encryptContext2.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + decrypted1, err := decryptContext.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + } +} + +func TestRTPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + decryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + decryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } +}