From 3eac36937db62aa00e6ea7889c6b94184fc9f59b Mon Sep 17 00:00:00 2001 From: "daniel@poradnik-webmastera.com" Date: Sat, 29 Jun 2024 21:30:57 +0200 Subject: [PATCH] Add support for Master Key Indicator This adds support for Master Key Indicator (MKI). It is used to select one of pre-configured SRTP/SRTCP encryption keys. To use it, Context has to be created with MasterKeyIndicator option, it specifies MKI for master key and salt passed to CreateContext. Additional master keys/salts with their MKIs can be added using AddCipherForMKI. To remove MKIs, use RemoveMKI. All MKIs must have the same length, and use the same length of master key and salt - they use the same crypto profile. SRTP/SRTCP packets by default are encrypted using first key/salt/MKI. To select other key/salt/MKI, use SetSendMKI. key/salt/MKI used for decryption are chosen automatically, using MKI sent in encrypted SRTP/SRTCP packet. --- context.go | 116 ++++++++++++++++------ context_test.go | 122 +++++++++++++++++++++++ errors.go | 9 +- option.go | 13 +++ srtcp.go | 15 ++- srtcp_test.go | 160 ++++++++++++++++++++++++++++++ srtp.go | 16 ++- srtp_cipher.go | 1 + srtp_cipher_aead_aes_gcm.go | 46 +++++++-- srtp_cipher_aes_cm_hmac_sha1.go | 57 +++++++++-- srtp_cipher_test.go | 166 +++++++++++++++++++++++++++++--- srtp_test.go | 158 ++++++++++++++++++++++++++++++ 12 files changed, 817 insertions(+), 62 deletions(-) diff --git a/context.go b/context.go index ecaa0d7..8078430 100644 --- a/context.go +++ b/context.go @@ -4,6 +4,7 @@ package srtp import ( + "bytes" "fmt" "github.com/pion/transport/v2/replaydetector" @@ -56,6 +57,11 @@ type Context struct { newSRTCPReplayDetector func() replaydetector.ReplayDetector newSRTPReplayDetector func() replaydetector.ReplayDetector + + profile ProtectionProfile + + sendMKI []byte // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled. + mkis map[string]srtpCipher // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled. } // CreateContext creates a new SRTP Context. @@ -66,52 +72,108 @@ type Context struct { // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { - keyLen, err := profile.KeyLen() + c = &Context{ + srtpSSRCStates: map[uint32]*srtpSSRCState{}, + srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + profile: profile, + mkis: map[string]srtpCipher{}, + } + + for _, o := range append( + []ContextOption{ // Default options + SRTPNoReplayProtection(), + SRTCPNoReplayProtection(), + }, + opts..., // User specified options + ) { + if errOpt := o(c); errOpt != nil { + return nil, errOpt + } + } + + c.cipher, err = c.createCipher(c.sendMKI, masterKey, masterSalt) if err != nil { return nil, err } + if len(c.sendMKI) != 0 { + c.mkis[string(c.sendMKI)] = c.cipher + } + + return c, nil +} - saltLen, err := profile.SaltLen() +// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option +// to enable MKI support. MKI must be unique and have the same length as the one used for creating Context. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { + if len(c.mkis) == 0 { + return errMKIIsNotEnabled + } + if len(mki) == 0 || len(mki) != len(c.sendMKI) { + return errInvalidMKILength + } + if _, ok := c.mkis[string(mki)]; ok { + return errMKIAlreadyInUse + } + + cipher, err := c.createCipher(mki, masterKey, masterSalt) + if err != nil { + return err + } + c.mkis[string(mki)] = cipher + return nil +} + +func (c *Context) createCipher(mki, masterKey, masterSalt []byte) (srtpCipher, error) { + keyLen, err := c.profile.KeyLen() if err != nil { return nil, err } - if masterKeyLen := len(masterKey); masterKeyLen != keyLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) - } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) + saltLen, err := c.profile.SaltLen() + if err != nil { + return nil, err } - c = &Context{ - srtpSSRCStates: map[uint32]*srtpSSRCState{}, - srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + if masterKeyLen := len(masterKey); masterKeyLen != keyLen { + return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) + } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { + return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) } - switch profile { + switch c.profile { case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: - c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) + return newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki) case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80: - c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) + return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki) default: - return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) - } - if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile) } +} - for _, o := range append( - []ContextOption{ // Default options - SRTPNoReplayProtection(), - SRTCPNoReplayProtection(), - }, - opts..., // User specified options - ) { - if errOpt := o(c); errOpt != nil { - return nil, errOpt - } +// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) RemoveMKI(mki []byte) error { + if _, ok := c.mkis[string(mki)]; !ok { + return ErrMKINotFound } + if bytes.Equal(mki, c.sendMKI) { + return errMKIAlreadyInUse + } + delete(c.mkis, string(mki)) + return nil +} - return c, nil +// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with encrypting packets. +func (c *Context) SetSendMKI(mki []byte) error { + cipher, ok := c.mkis[string(mki)] + if !ok { + return ErrMKINotFound + } + c.sendMKI = mki + c.cipher = cipher + return nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 diff --git a/context_test.go b/context_test.go index 60b078c..81b9b90 100644 --- a/context_test.go +++ b/context_test.go @@ -5,6 +5,8 @@ package srtp import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestContextROC(t *testing.T) { @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) { t.Errorf("Index is set to 100, but returned %d", index) } } + +func TestContextWithoutMKI(t *testing.T) { + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.SetSendMKI(nil) + assert.Error(t, err) + + err = c.SetSendMKI(make([]byte, 0)) + assert.Error(t, err) + + err = c.RemoveMKI(nil) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 0)) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 2)) + assert.Error(t, err) +} + +func TestAddMKIToContextWithMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) +} + +func TestContextSetSendMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.SetSendMKI(mki1) + assert.NoError(t, err) + + err = c.SetSendMKI(mki2) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 4)) + assert.Error(t, err) +} + +func TestContextRemoveMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + mki3 := []byte{3, 4, 5, 6} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.RemoveMKI(make([]byte, 4)) + assert.Error(t, err) + + err = c.RemoveMKI(mki1) + assert.Error(t, err) + + err = c.SetSendMKI(mki3) + assert.NoError(t, err) + + err = c.RemoveMKI(mki1) + assert.NoError(t, err) + + err = c.RemoveMKI(mki2) + assert.NoError(t, err) + + err = c.RemoveMKI(mki3) + assert.Error(t, err) +} diff --git a/errors.go b/errors.go index f918b9c..c22653f 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,11 @@ import ( ) var ( + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag + ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") + // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet + ErrMKINotFound = errors.New("MKI not found") + errDuplicated = errors.New("duplicated packet") errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") @@ -17,13 +22,15 @@ var ( errExporterWrongLabel = errors.New("exporter called with wrong label") errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") - errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") errTooShortRTP = errors.New("packet is too short to be RTP packet") errTooShortRTCP = errors.New("packet is too short to be RTCP packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") errExceededMaxPackets = errors.New("exceeded the maximum number of packets") + errMKIAlreadyInUse = errors.New("MKI already in use") + errMKIIsNotEnabled = errors.New("MKI is not enabled") + errInvalidMKILength = errors.New("invalid MKI length") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") diff --git a/option.go b/option.go index 0c75096..e5899bb 100644 --- a/option.go +++ b/option.go @@ -71,3 +71,16 @@ type nopReplayDetector struct{} func (s *nopReplayDetector) Check(uint64) (func(), bool) { return func() {}, true } + +// MasterKeyIndicator sets RTP/RTCP MKI for the initial master key. Array passed as an argument will be +// copied as-is to encrypted SRTP/SRTCP packets, so it must be of proper length and in Big Endian format. +// All MKIs added later using Context.AddCipherForMKI must have the same length as the one used here. +func MasterKeyIndicator(mki []byte) ContextOption { + return func(c *Context) error { + if len(mki) > 0 { + c.sendMKI = make([]byte, len(mki)) + copy(c.sendMKI, mki) + } + return nil + } +} diff --git a/srtcp.go b/srtcp.go index 86a963f..1359fcd 100644 --- a/srtcp.go +++ b/srtcp.go @@ -23,7 +23,8 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + mkiLen := len(c.sendMKI) + tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize) if tailOffset < aeadAuthTagLen { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } - out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(encrypted, false) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + out, err = cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } diff --git a/srtcp_test.go b/srtcp_test.go index 3018757..5b0798a 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -618,3 +618,163 @@ func TestEncryptInvalidRTCP(t *testing.T) { _, err = decryptContext.EncryptRTCP(nil, packet, nil) assert.Error(err) } + +func TestRTCPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + _, err = decryptContext.DecryptRTCP(nil, encrypted, nil) + if err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SSRC: %d", pkt.ssrc) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } + }) + } +} + +func TestRTCPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + encryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext1.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + encrypted2, err := encryptContext2.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + decrypted1, err := decryptContext.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + } + }) + } +} + +func TestRTCPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + decryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + decryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } + }) + } +} diff --git a/srtp.go b/srtp.go index e02ceff..ca7c3fc 100644 --- a/srtp.go +++ b/srtp.go @@ -14,7 +14,7 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL return nil, err } - if len(ciphertext) < headerLen+authTagLen { + if len(ciphertext) < headerLen+len(c.sendMKI)+authTagLen { return nil, errTooShortRTP } @@ -30,9 +30,19 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } } - dst = growBufferSize(dst, len(ciphertext)-authTagLen) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(ciphertext, true) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + dst = growBufferSize(dst, len(ciphertext)-authTagLen-len(c.sendMKI)) - dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) + dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } diff --git a/srtp_cipher.go b/srtp_cipher.go index 1af88df..da745e7 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -16,6 +16,7 @@ type srtpCipher interface { // See the note below. AEADAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 + getMKI([]byte, bool) []byte encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 744fbbb..e59c141 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -21,9 +21,11 @@ type srtpCipherAeadAesGcm struct { srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte + + mki []byte } -func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { +func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAeadAesGcm, error) { s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) @@ -62,6 +64,12 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt [] return nil, err } + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -71,7 +79,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen+len(s.mki)) n, err := header.MarshalTo(dst) if err != nil { @@ -80,6 +88,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa iv := s.rtpInitializationVector(header, roc) s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) + + // Add MKI after the encrypted payload + if len(s.mki) > 0 { + copy(dst[len(dst)-len(s.mki):], s.mki) + } + return dst, nil } @@ -89,17 +103,18 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He if err != nil { return nil, err } - nDst := len(ciphertext) - authTagLen - if nDst < 0 { + nDst := len(ciphertext) - authTagLen - len(s.mki) + if nDst < headerLen { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtpInitializationVector(header, roc) + nEnd := len(ciphertext) - len(s.mki) if _, err := s.srtpCipher.Open( - dst[headerLen:headerLen], iv[:], ciphertext[headerLen:], ciphertext[:headerLen], + dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen], ); err != nil { return nil, err } @@ -115,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin } aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. - dst = growBufferSize(dst, aadPos+srtcpIndexSize) + dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki)) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) @@ -124,11 +139,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin copy(dst[:8], decrypted[:8]) copy(dst[aadPos:aadPos+4], aad[8:12]) + copy(dst[aadPos+4:], s.mki) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { - aadPos := len(encrypted) - srtcpIndexSize + aadPos := len(encrypted) - srtcpIndexSize - len(s.mki) // Grow the given buffer to fit the output. authTagLen, err := s.AEADAuthTagLen() if err != nil { @@ -137,7 +153,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) @@ -205,5 +221,15 @@ func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte } func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { - return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) + return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-4:]) &^ (rtcpEncryptionFlag << 24) +} + +func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + tailOffset := len(in) - mkiLen + return in[tailOffset:] } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index 11369af..0ffe51a 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -25,9 +25,11 @@ type srtpCipherAesCmHmacSha1 struct { srtcpSessionSalt []byte srtcpSessionAuth hash.Hash srtcpBlock cipher.Block + + mki []byte } -func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { +func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -66,6 +68,13 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -75,7 +84,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) @@ -96,6 +105,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay return nil, err } + // Append the MKI (if used) + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += len(s.mki) + } + // Write the auth tag to the dest. copy(dst[n:], authTag) @@ -108,8 +123,10 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp if err != nil { return nil, err } + + // Split the auth tag and the cipher text into two parts. actualTag := ciphertext[len(ciphertext)-authTagLen:] - ciphertext = ciphertext[:len(ciphertext)-authTagLen] + ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) @@ -120,7 +137,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } // Write the plaintext header to the destination buffer. @@ -148,10 +165,18 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex) dst[len(dst)-4] |= 0x80 + // Generate the authentication tag authTag, err := s.generateSrtcpAuthTag(dst) if err != nil { return nil, err } + + // Include the MKI if provided + if len(s.mki) > 0 { + dst = append(dst, s.mki...) + } + + // Append the auth tag at the end of the buffer return append(dst, authTag...), nil } @@ -160,20 +185,20 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + tailOffset := len(encrypted) - (authTagLen + len(s.mki) + srtcpIndexSize) if tailOffset < 8 { return nil, errTooShortRTCP } out = out[0:tailOffset] - expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-len(s.mki)-authTagLen]) if err != nil { return nil, err } actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) @@ -247,7 +272,23 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.AuthTagRTCPLen() - tailOffset := len(in) - (authTagLen + srtcpIndexSize) + tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } + +func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + var authTagLen int + if rtp { + authTagLen, _ = s.AuthTagRTPLen() + } else { + authTagLen, _ = s.AuthTagRTCPLen() + } + tailOffset := len(in) - (authTagLen + mkiLen) + return in[tailOffset : tailOffset+mkiLen] +} diff --git a/srtp_cipher_test.go b/srtp_cipher_test.go index fa99ab7..262b89c 100644 --- a/srtp_cipher_test.go +++ b/srtp_cipher_test.go @@ -13,17 +13,21 @@ type testCipher struct { profile ProtectionProfile // Protection profile masterKey []byte // Master key masterSalt []byte // Master salt + mki []byte // Master key identifier - decryptedRTPPacket []byte - encryptedRTPPacket []byte - decryptedRTCPPacket []byte - encryptedRTCPPacket []byte + decryptedRTPPacket []byte + encryptedRTPPacket []byte + encryptedRTPPacketWithMKI []byte + + decryptedRTCPPacket []byte + encryptedRTCPPacket []byte + encryptedRTCPPacketWithMKI []byte } // create array of testCiphers for each supported profile func createTestCiphers() []testCipher { tests := []testCipher{ - { + { //nolint:dupl profile: ProtectionProfileAes128CmHmacSha1_32, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -38,8 +42,23 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x00, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAes128CmHmacSha1_80, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -55,8 +74,24 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0x01, 0x02, 0x00, 0x04, + 0xcd, 0x7c, 0x0d, 0x09, 0xca, 0x44, 0x32, 0xa5, + 0x6e, 0x3d, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + 0x3d, 0xb7, 0xa1, 0x98, 0x37, 0xff, 0x64, 0xe5, + 0xcb, 0xd2, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAes256CmHmacSha1_32, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -71,8 +106,23 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0xbf, 0x18, 0x18, 0x2d, 0xd1, 0x18, 0x81, 0x28, 0x78, 0xb1, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xac, 0x3b, 0xca, 0x88, + 0x14, 0x37, 0x57, 0x83, 0x35, 0xc6, 0xd4, 0x57, + 0xf1, 0xc3, 0x6b, 0xa7, 0x01, 0x02, 0x00, 0x04, + 0x3d, 0x71, 0x48, 0x63, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x97, 0x04, 0x31, 0xdc, 0x4a, 0xe6, 0xd2, 0xaf, + 0xd6, 0x54, 0xbf, 0x90, 0xf4, 0x35, 0x44, 0x9e, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + 0xbf, 0x18, 0x18, 0x2d, 0xd1, 0x18, 0x81, 0x28, + 0x78, 0xb1, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAes256CmHmacSha1_80, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -88,8 +138,24 @@ func createTestCiphers() []testCipher { 0x80, 0x00, 0x00, 0x01, 0xbf, 0x18, 0x18, 0x2d, 0xd1, 0x18, 0x81, 0x28, 0x78, 0xb1, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xac, 0x3b, 0xca, 0x88, + 0x14, 0x37, 0x57, 0x83, 0x35, 0xc6, 0xd4, 0x57, + 0xf1, 0xc3, 0x6b, 0xa7, 0x01, 0x02, 0x00, 0x04, + 0x3d, 0x71, 0x48, 0x63, 0x90, 0x9b, 0xbf, 0x15, + 0xac, 0xec, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x97, 0x04, 0x31, 0xdc, 0x4a, 0xe6, 0xd2, 0xaf, + 0xd6, 0x54, 0xbf, 0x90, 0xf4, 0x35, 0x44, 0x9e, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + 0xbf, 0x18, 0x18, 0x2d, 0xd1, 0x18, 0x81, 0x28, + 0x78, 0xb1, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes128Gcm, encryptedRTPPacket: []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -107,11 +173,27 @@ func createTestCiphers() []testCipher { 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, 0xde, + 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, + 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, 0x68, 0x26, + 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, + 0x27, 0xe8, 0xa3, 0x92, 0x01, 0x02, 0x00, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, + 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, + 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, + 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + }, }, - { + { //nolint:dupl profile: ProtectionProfileAeadAes256Gcm, encryptedRTPPacket: []byte{ - 0x80, 0xf, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, @@ -126,6 +208,22 @@ func createTestCiphers() []testCipher { 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, 0x80, 0x00, 0x00, 0x01, }, + encryptedRTPPacketWithMKI: []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xaf, 0x49, 0x96, 0x8f, + 0x7e, 0x9c, 0x43, 0xf8, 0x01, 0xdd, 0x0c, 0x84, + 0x8b, 0x1e, 0xc9, 0xb0, 0x29, 0xcd, 0xf8, 0x5c, + 0xb7, 0x9a, 0x2f, 0x95, 0x60, 0xd4, 0x69, 0x75, + 0x98, 0x50, 0x77, 0x25, 0x01, 0x02, 0x00, 0x04, + }, + encryptedRTCPPacketWithMKI: []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0x98, 0x22, 0xba, 0x22, 0x96, 0x1c, 0x31, 0x48, + 0xe7, 0xb7, 0xec, 0x4f, 0x09, 0xf4, 0x26, 0xdc, + 0xf6, 0xb5, 0x9a, 0x75, 0xad, 0xec, 0x74, 0xfd, + 0xb9, 0x51, 0xb6, 0x66, 0x84, 0x24, 0xd4, 0xe2, + 0x80, 0x00, 0x00, 0x01, 0x01, 0x02, 0x00, 0x04, + }, }, } @@ -139,6 +237,7 @@ func createTestCiphers() []testCipher { 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, } + mki := []byte{0x01, 0x02, 0x00, 0x04} decryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, @@ -162,6 +261,7 @@ func createTestCiphers() []testCipher { } tests[k].masterKey = masterKey[:keyLen] tests[k].masterSalt = masterSalt[:saltLen] + tests[k].mki = mki tests[k].decryptedRTPPacket = decryptedRTPPacket tests[k].decryptedRTCPPacket = decryptedRTCPPacket } @@ -215,6 +315,50 @@ func TestSrtpCipher(t *testing.T) { assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) }) }) + + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, c.encryptedRTCPPacketWithMKI, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacketWithMKI, nil) + assert.NoError(t, err) + assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + }) + }) }) } } diff --git a/srtp_test.go b/srtp_test.go index 34c483c..6ff6f7e 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -896,3 +896,161 @@ func TestDecryptInvalidSRTP(t *testing.T) { _, err = decryptContext.DecryptRTP(nil, packet, nil) assert.Error(err) } + +func TestRTPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := decryptContext.DecryptRTP(nil, out, nil); err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SeqNum: %d", testCase.sequenceNumber) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } +} + +func TestRTPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + encryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext1.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + encrypted2, err := encryptContext2.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + decrypted1, err := decryptContext.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + } +} + +func TestRTPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + decryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + decryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } +}