Skip to content

Commit

Permalink
Add support for Master Key Indicator
Browse files Browse the repository at this point in the history
This adds support for Master Key Indicator (MKI). It is used to select
one of pre-configured SRTP/SRTCP encryption keys.

To use it, Context has to be created with MasterKeyIndicator option,
it specifies MKI for master key and salt passed to CreateContext.
Additional master keys/salts with their MKIs can be added using
AddCipherForMKI. To remove MKIs, use RemoveMKI.

All MKIs must have the same length, and use the same length of master
key and salt - they use the same crypto profile.

SRTP/SRTCP packets by default are encrypted using first key/salt/MKI.
To select other key/salt/MKI, use SetSendMKI.

key/salt/MKI used for decryption are chosen automatically, using MKI
sent in encrypted SRTP/SRTCP packet.
  • Loading branch information
[email protected] authored and sirzooro committed Jul 17, 2024
1 parent f8540ec commit 3eac369
Show file tree
Hide file tree
Showing 12 changed files with 817 additions and 62 deletions.
116 changes: 89 additions & 27 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package srtp

import (
"bytes"
"fmt"

"github.com/pion/transport/v2/replaydetector"
Expand Down Expand Up @@ -56,6 +57,11 @@ type Context struct {

newSRTCPReplayDetector func() replaydetector.ReplayDetector
newSRTPReplayDetector func() replaydetector.ReplayDetector

profile ProtectionProfile

sendMKI []byte // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled.
mkis map[string]srtpCipher // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled.
}

// CreateContext creates a new SRTP Context.
Expand All @@ -66,52 +72,108 @@ type Context struct {
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.KeyLen()
c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
profile: profile,
mkis: map[string]srtpCipher{},
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
}

c.cipher, err = c.createCipher(c.sendMKI, masterKey, masterSalt)
if err != nil {
return nil, err
}
if len(c.sendMKI) != 0 {
c.mkis[string(c.sendMKI)] = c.cipher
}

return c, nil
}

saltLen, err := profile.SaltLen()
// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option
// to enable MKI support. MKI must be unique and have the same length as the one used for creating Context.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
if len(c.mkis) == 0 {
return errMKIIsNotEnabled
}
if len(mki) == 0 || len(mki) != len(c.sendMKI) {
return errInvalidMKILength
}
if _, ok := c.mkis[string(mki)]; ok {
return errMKIAlreadyInUse
}

cipher, err := c.createCipher(mki, masterKey, masterSalt)
if err != nil {
return err
}
c.mkis[string(mki)] = cipher
return nil
}

func (c *Context) createCipher(mki, masterKey, masterSalt []byte) (srtpCipher, error) {
keyLen, err := c.profile.KeyLen()
if err != nil {
return nil, err
}

if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
saltLen, err := c.profile.SaltLen()
if err != nil {
return nil, err
}

c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
}

switch profile {
switch c.profile {
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
return newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki)
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80:
c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki)
default:
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile)
}
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
func (c *Context) RemoveMKI(mki []byte) error {
if _, ok := c.mkis[string(mki)]; !ok {
return ErrMKINotFound
}
if bytes.Equal(mki, c.sendMKI) {
return errMKIAlreadyInUse
}
delete(c.mkis, string(mki))
return nil
}

return c, nil
// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with encrypting packets.
func (c *Context) SetSendMKI(mki []byte) error {
cipher, ok := c.mkis[string(mki)]
if !ok {
return ErrMKINotFound
}
c.sendMKI = mki
c.cipher = cipher
return nil
}

// https://tools.ietf.org/html/rfc3550#appendix-A.1
Expand Down
122 changes: 122 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package srtp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestContextROC(t *testing.T) {
Expand Down Expand Up @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) {
t.Errorf("Index is set to 100, but returned %d", index)
}
}

func TestContextWithoutMKI(t *testing.T) {
c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR)
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.SetSendMKI(nil)
assert.Error(t, err)

err = c.SetSendMKI(make([]byte, 0))
assert.Error(t, err)

err = c.RemoveMKI(nil)
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 0))
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 2))
assert.Error(t, err)
}

func TestAddMKIToContextWithMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)
}

func TestContextSetSendMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.SetSendMKI(mki1)
assert.NoError(t, err)

err = c.SetSendMKI(mki2)
assert.NoError(t, err)

err = c.SetSendMKI(make([]byte, 4))
assert.Error(t, err)
}

func TestContextRemoveMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}
mki3 := []byte{3, 4, 5, 6}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.RemoveMKI(make([]byte, 4))
assert.Error(t, err)

err = c.RemoveMKI(mki1)
assert.Error(t, err)

err = c.SetSendMKI(mki3)
assert.NoError(t, err)

err = c.RemoveMKI(mki1)
assert.NoError(t, err)

err = c.RemoveMKI(mki2)
assert.NoError(t, err)

err = c.RemoveMKI(mki3)
assert.Error(t, err)
}
9 changes: 8 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
)

var (
// ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag
ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
// ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet
ErrMKINotFound = errors.New("MKI not found")

errDuplicated = errors.New("duplicated packet")
errShortSrtpMasterKey = errors.New("SRTP master key is not long enough")
errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough")
Expand All @@ -17,13 +22,15 @@ var (
errExporterWrongLabel = errors.New("exporter called with wrong label")
errNoConfig = errors.New("no config provided")
errNoConn = errors.New("no conn provided")
errFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
errTooShortRTP = errors.New("packet is too short to be RTP packet")
errTooShortRTCP = errors.New("packet is too short to be RTCP packet")
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
errExceededMaxPackets = errors.New("exceeded the maximum number of packets")
errMKIAlreadyInUse = errors.New("MKI already in use")
errMKIIsNotEnabled = errors.New("MKI is not enabled")
errInvalidMKILength = errors.New("invalid MKI length")

errStreamNotInited = errors.New("stream has not been inited, unable to close")
errStreamAlreadyClosed = errors.New("stream is already closed")
Expand Down
13 changes: 13 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,16 @@ type nopReplayDetector struct{}
func (s *nopReplayDetector) Check(uint64) (func(), bool) {
return func() {}, true
}

// MasterKeyIndicator sets RTP/RTCP MKI for the initial master key. Array passed as an argument will be
// copied as-is to encrypted SRTP/SRTCP packets, so it must be of proper length and in Big Endian format.
// All MKIs added later using Context.AddCipherForMKI must have the same length as the one used here.
func MasterKeyIndicator(mki []byte) ContextOption {
return func(c *Context) error {
if len(mki) > 0 {
c.sendMKI = make([]byte, len(mki))
copy(c.sendMKI, mki)
}
return nil
}
}
15 changes: 13 additions & 2 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
if err != nil {
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
mkiLen := len(c.sendMKI)
tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize)

if tailOffset < aeadAuthTagLen {
return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted))
Expand All @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index}
}

out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc)
cipher := c.cipher
if len(c.mkis) > 0 {
// Find cipher for MKI
actualMKI := c.cipher.getMKI(encrypted, false)
cipher, ok = c.mkis[string(actualMKI)]
if !ok {
return nil, ErrMKINotFound
}
}

out, err = cipher.decryptRTCP(out, encrypted, index, ssrc)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 3eac369

Please sign in to comment.