Skip to content

Commit

Permalink
Add support for Master Key Indicator
Browse files Browse the repository at this point in the history
This adds support for Master Key Indicator (MKI). It is used to select
one of pre-configured SRTP/SRTCP encryption keys.

To use it, Context has to be created with MasterKeyIndicator option,
it specifies MKI for master key and salt passed to CreateContext.
Additional master keys/salts with their MKIs can be added using
AddCipherForMKI. To remove MKIs, use RemoveMKI.

All MKIs must have the same length, and use the same length of master
key and salt - they use the same crypto profile.

SRTP/SRTCP packets by default are encrypted using first key/salt/MKI.
To select other key/salt/MKI, use SetSendMKI.

key/salt/MKI used for decryption are chosen automatically, using MKI
sent in encrypted SRTP/SRTCP packet.
  • Loading branch information
[email protected] authored and sirzooro committed Jul 9, 2024
1 parent e3e6d11 commit 5149cc5
Show file tree
Hide file tree
Showing 12 changed files with 767 additions and 56 deletions.
104 changes: 78 additions & 26 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package srtp

import (
"bytes"
"fmt"

"github.com/pion/transport/v3/replaydetector"
Expand Down Expand Up @@ -56,6 +57,10 @@ type Context struct {

newSRTCPReplayDetector func() replaydetector.ReplayDetector
newSRTPReplayDetector func() replaydetector.ReplayDetector

profile ProtectionProfile
sendMKI []byte
mkis map[string]srtpCipher
}

// CreateContext creates a new SRTP Context.
Expand All @@ -66,52 +71,99 @@ type Context struct {
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.KeyLen()
c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
profile: profile,
mkis: map[string]srtpCipher{},
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
}

err = c.AddCipherForMKI(c.sendMKI, masterKey, masterSalt)
if err != nil {
return nil, err
}
c.cipher = c.mkis[string(c.sendMKI)]

saltLen, err := profile.SaltLen()
return c, nil
}

// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option to enable MKI support.
func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
if len(mki) != len(c.sendMKI) {
return errInvalidMKILength
}
if _, ok := c.mkis[string(mki)]; ok {
return errMKIAlreadyInUse
}

keyLen, err := c.profile.KeyLen()
if err != nil {
return nil, err
return err
}

if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
saltLen, err := c.profile.SaltLen()
if err != nil {
return err
}

c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
}

switch profile {
var cipher srtpCipher
switch c.profile {
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
cipher, err = newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki)
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
cipher, err = newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki)
default:
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
return fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile)
}
if err != nil {
return nil, err
return err
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
c.mkis[string(mki)] = cipher
return nil
}

// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
func (c *Context) RemoveMKI(mki []byte) error {
if _, ok := c.mkis[string(mki)]; !ok {
return ErrMKINotFound
}
if bytes.Equal(mki, c.sendMKI) {
return errMKIAlreadyInUse
}
delete(c.mkis, string(mki))
return nil
}

return c, nil
// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with encrypting packets.
func (c *Context) SetSendMKI(mki []byte) error {
cipher, ok := c.mkis[string(mki)]
if !ok {
return ErrMKINotFound
}
c.sendMKI = mki
c.cipher = cipher
return nil
}

// https://tools.ietf.org/html/rfc3550#appendix-A.1
Expand Down
122 changes: 122 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package srtp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestContextROC(t *testing.T) {
Expand Down Expand Up @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) {
t.Errorf("Index is set to 100, but returned %d", index)
}
}

func TestContextWithoutMKI(t *testing.T) {
c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR)
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.SetSendMKI(nil)
assert.NoError(t, err)

err = c.SetSendMKI(make([]byte, 0))
assert.NoError(t, err)

err = c.RemoveMKI(nil)
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 0))
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 2))
assert.Error(t, err)
}

func TestAddMKIToContextWithMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)
}

func TestContextSetSendMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.SetSendMKI(mki1)
assert.NoError(t, err)

err = c.SetSendMKI(mki2)
assert.NoError(t, err)

err = c.SetSendMKI(make([]byte, 4))
assert.Error(t, err)
}

func TestContextRemoveMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}
mki3 := []byte{3, 4, 5, 6}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.RemoveMKI(make([]byte, 4))
assert.Error(t, err)

err = c.RemoveMKI(mki1)
assert.Error(t, err)

err = c.SetSendMKI(mki3)
assert.NoError(t, err)

err = c.RemoveMKI(mki1)
assert.NoError(t, err)

err = c.RemoveMKI(mki2)
assert.NoError(t, err)

err = c.RemoveMKI(mki3)
assert.Error(t, err)
}
8 changes: 7 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
)

var (
// ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag
ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
// ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet
ErrMKINotFound = errors.New("MKI not found")

errDuplicated = errors.New("duplicated packet")
errShortSrtpMasterKey = errors.New("SRTP master key is not long enough")
errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough")
Expand All @@ -17,13 +22,14 @@ var (
errExporterWrongLabel = errors.New("exporter called with wrong label")
errNoConfig = errors.New("no config provided")
errNoConn = errors.New("no conn provided")
errFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
errTooShortRTP = errors.New("packet is too short to be RTP packet")
errTooShortRTCP = errors.New("packet is too short to be RTCP packet")
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
errExceededMaxPackets = errors.New("exceeded the maximum number of packets")
errMKIAlreadyInUse = errors.New("MKI already in use")
errInvalidMKILength = errors.New("invalid MKI length")

errStreamNotInited = errors.New("stream has not been inited, unable to close")
errStreamAlreadyClosed = errors.New("stream is already closed")
Expand Down
11 changes: 11 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ type nopReplayDetector struct{}
func (s *nopReplayDetector) Check(uint64) (func() bool, bool) {
return func() bool { return true }, true
}

// MasterKeyIndicator sets MKI for RTP and RTCP.
func MasterKeyIndicator(mki []byte) ContextOption {
return func(c *Context) error {
if len(mki) > 0 {
c.sendMKI = make([]byte, len(mki))
copy(c.sendMKI, mki)
}
return nil
}
}
17 changes: 14 additions & 3 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
if err != nil {
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
mkiLen := len(c.sendMKI)
tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize)

if tailOffset < aeadAuthTagLen {
if tailOffset < aeadAuthTagLen+8 {
return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted))
} else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 {
return out, nil
Expand All @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index}
}

out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc)
cipher := c.cipher
if len(c.mkis) > 0 {
// Find cipher for MKI
actualMKI := c.cipher.getMKI(encrypted, false)
cipher, ok = c.mkis[string(actualMKI)]
if !ok {
return nil, ErrMKINotFound
}
}

out, err = cipher.decryptRTCP(out, encrypted, index, ssrc)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 5149cc5

Please sign in to comment.