Skip to content

Commit

Permalink
Make ProtectionProfile methods public
Browse files Browse the repository at this point in the history
This allows to get values like key length and use them with key
management protocols like MIKEY, instead of creating own constants
for this.

Resolves pion#258
  • Loading branch information
sirzooro committed Jul 8, 2024
1 parent e9fc319 commit e3e6d11
Show file tree
Hide file tree
Showing 15 changed files with 56 additions and 49 deletions.
4 changes: 2 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ type Context struct {
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.keyLen()
keyLen, err := profile.KeyLen()
if err != nil {
return nil, err
}

saltLen, err := profile.saltLen()
saltLen, err := profile.SaltLen()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion key_derivation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestValidSessionKeys(t *testing.T) {
t.Errorf("Session Salt % 02x does not match expected % 02x", sessionSalt, expectedSessionSalt)
}

authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen()
authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.AuthKeyLen()
assert.NoError(t, err)

sessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen)
Expand Down
4 changes: 2 additions & 2 deletions keying.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ type KeyingMaterialExporter interface {
// extracting them from DTLS. This behavior is defined in RFC5764:
// https://tools.ietf.org/html/rfc5764
func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isClient bool) error {
keyLen, err := c.Profile.keyLen()
keyLen, err := c.Profile.KeyLen()
if err != nil {
return err
}

saltLen, err := c.Profile.saltLen()
saltLen, err := c.Profile.SaltLen()
if err != nil {
return err
}
Expand Down
19 changes: 13 additions & 6 deletions protection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const (
ProtectionProfileAeadAes256Gcm ProtectionProfile = 0x0008
)

func (p ProtectionProfile) keyLen() (int, error) {
// KeyLen returns length of encryption key in bytes.
func (p ProtectionProfile) KeyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm:
return 16, nil
Expand All @@ -28,7 +29,8 @@ func (p ProtectionProfile) keyLen() (int, error) {
}
}

func (p ProtectionProfile) saltLen() (int, error) {
// SaltLen returns length of salt key in bytes.
func (p ProtectionProfile) SaltLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
return 14, nil
Expand All @@ -39,7 +41,8 @@ func (p ProtectionProfile) saltLen() (int, error) {
}
}

func (p ProtectionProfile) rtpAuthTagLen() (int, error) {
// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero.
func (p ProtectionProfile) AuthTagRTPLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 10, nil
Expand All @@ -52,7 +55,8 @@ func (p ProtectionProfile) rtpAuthTagLen() (int, error) {
}
}

func (p ProtectionProfile) rtcpAuthTagLen() (int, error) {
// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero.
func (p ProtectionProfile) AuthTagRTCPLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
return 10, nil
Expand All @@ -63,7 +67,8 @@ func (p ProtectionProfile) rtcpAuthTagLen() (int, error) {
}
}

func (p ProtectionProfile) aeadAuthTagLen() (int, error) {
// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. For AES ones it returns zero.
func (p ProtectionProfile) AEADAuthTagLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
return 0, nil
Expand All @@ -74,7 +79,8 @@ func (p ProtectionProfile) aeadAuthTagLen() (int, error) {
}
}

func (p ProtectionProfile) authKeyLen() (int, error) {
// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. For AEAD ones it returns zero.
func (p ProtectionProfile) AuthKeyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
return 20, nil
Expand All @@ -85,6 +91,7 @@ func (p ProtectionProfile) authKeyLen() (int, error) {
}
}

// String returns the name of the protection profile.
func (p ProtectionProfile) String() string {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
Expand Down
4 changes: 2 additions & 2 deletions protection_profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
func TestInvalidProtectionProfile(t *testing.T) {
var invalidProtectionProfile ProtectionProfile

_, err := invalidProtectionProfile.keyLen()
_, err := invalidProtectionProfile.KeyLen()
assert.Error(t, err)

_, err = invalidProtectionProfile.saltLen()
_, err = invalidProtectionProfile.SaltLen()
assert.Error(t, err)
}
2 changes: 1 addition & 1 deletion session_srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) {
}

func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) {
authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.rtcpAuthTagLen()
authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.AuthTagRTCPLen()
if err != nil {
return 0, err
}
Expand Down
4 changes: 2 additions & 2 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ const maxSRTCPIndex = 0x7FFFFFFF
func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
out := allocateIfMismatch(dst, encrypted)

authTagLen, err := c.cipher.rtcpAuthTagLen()
authTagLen, err := c.cipher.AuthTagRTCPLen()
if err != nil {
return nil, err
}
aeadAuthTagLen, err := c.cipher.aeadAuthTagLen()
aeadAuthTagLen, err := c.cipher.AEADAuthTagLen()
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ func TestRTCPLifecycleInPlace(t *testing.T) {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
assert := assert.New(t)
authTagLen, err := testCase.algo.rtcpAuthTagLen()
authTagLen, err := testCase.algo.AuthTagRTCPLen()
assert.NoError(err)

aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen()
aeadAuthTagLen, err := testCase.algo.AEADAuthTagLen()
assert.NoError(err)

encryptHeader := &rtcp.Header{}
Expand Down Expand Up @@ -272,10 +272,10 @@ func TestRTCPInvalidAuthTag(t *testing.T) {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
assert := assert.New(t)
authTagLen, err := testCase.algo.rtcpAuthTagLen()
authTagLen, err := testCase.algo.AuthTagRTCPLen()
assert.NoError(err)

aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen()
aeadAuthTagLen, err := testCase.algo.AEADAuthTagLen()
assert.NoError(err)

decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo)
Expand Down Expand Up @@ -354,7 +354,7 @@ func TestEncryptRTCPSeparation(t *testing.T) {
encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo)
assert.NoError(err)

authTagLen, err := testCase.algo.rtcpAuthTagLen()
authTagLen, err := testCase.algo.AuthTagRTCPLen()
assert.NoError(err)

decryptContext, err := CreateContext(
Expand Down
2 changes: 1 addition & 1 deletion srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) {
authTagLen, err := c.cipher.rtpAuthTagLen()
authTagLen, err := c.cipher.AuthTagRTPLen()
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions srtp_cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import "github.com/pion/rtp"
// cipher represents a implementation of one
// of the SRTP Specific ciphers
type srtpCipher interface {
// authTagLen returns auth key length of the cipher.
// AuthTagRTPLen/AuthTagRTCPLen return auth key length of the cipher.
// See the note below.
rtpAuthTagLen() (int, error)
rtcpAuthTagLen() (int, error)
// aeadAuthTagLen returns AEAD auth key length of the cipher.
AuthTagRTPLen() (int, error)
AuthTagRTCPLen() (int, error)
// AEADAuthTagLen returns AEAD auth key length of the cipher.
// See the note below.
aeadAuthTagLen() (int, error)
AEADAuthTagLen() (int, error)
getRTCPIndex([]byte) uint32

encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error)
Expand Down
8 changes: 4 additions & 4 deletions srtp_cipher_aead_aes_gcm.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []

func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) {
// Grow the given buffer to fit the output.
authTagLen, err := s.aeadAuthTagLen()
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
Expand All @@ -85,7 +85,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa

func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) {
// Grow the given buffer to fit the output.
authTagLen, err := s.aeadAuthTagLen()
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
Expand All @@ -109,7 +109,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He
}

func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) {
authTagLen, err := s.aeadAuthTagLen()
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
Expand All @@ -130,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin
func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) {
aadPos := len(encrypted) - srtcpIndexSize
// Grow the given buffer to fit the output.
authTagLen, err := s.aeadAuthTagLen()
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions srtp_cipher_aes_cm_hmac_sha1.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt
return nil, err
}

authKeyLen, err := profile.authKeyLen()
authKeyLen, err := profile.AuthKeyLen()
if err != nil {
return nil, err
}
Expand All @@ -71,7 +71,7 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt

func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) {
// Grow the given buffer to fit the output.
authTagLen, err := s.rtpAuthTagLen()
authTagLen, err := s.AuthTagRTPLen()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,7 +104,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay

func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) {
// Split the auth tag and the cipher text into two parts.
authTagLen, err := s.rtpAuthTagLen()
authTagLen, err := s.AuthTagRTPLen()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex
}

func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc uint32) ([]byte, error) {
authTagLen, err := s.rtcpAuthTagLen()
authTagLen, err := s.AuthTagRTCPLen()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -213,7 +213,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([
}

// Truncate the hash to the size indicated by the profile
authTagLen, err := s.rtpAuthTagLen()
authTagLen, err := s.AuthTagRTPLen()
if err != nil {
return nil, err
}
Expand All @@ -237,7 +237,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro
if _, err := s.srtcpSessionAuth.Write(buf); err != nil {
return nil, err
}
authTagLen, err := s.rtcpAuthTagLen()
authTagLen, err := s.AuthTagRTCPLen()
if err != nil {
return nil, err
}
Expand All @@ -246,7 +246,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro
}

func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 {
authTagLen, _ := s.rtcpAuthTagLen()
authTagLen, _ := s.AuthTagRTCPLen()
tailOffset := len(in) - (authTagLen + srtcpIndexSize)
srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize]
return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31)
Expand Down
4 changes: 2 additions & 2 deletions srtp_cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ func createTestCiphers() []testCipher {
}

for k, v := range tests {
keyLen, err := v.profile.keyLen()
keyLen, err := v.profile.KeyLen()
if err != nil {
panic(err)
}
saltLen, err := v.profile.saltLen()
saltLen, err := v.profile.SaltLen()
if err != nil {
panic(err)
}
Expand Down
10 changes: 5 additions & 5 deletions srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ func (tc rtpTestCase) encrypted(profile ProtectionProfile) []byte {
}

func testKeyLen(t *testing.T, profile ProtectionProfile) {
keyLen, err := profile.keyLen()
keyLen, err := profile.KeyLen()
assert.NoError(t, err)

saltLen, err := profile.saltLen()
saltLen, err := profile.SaltLen()
assert.NoError(t, err)

if _, err := CreateContext([]byte{}, make([]byte, saltLen), profile); err == nil {
Expand Down Expand Up @@ -181,11 +181,11 @@ func TestRolloverCountOverflow(t *testing.T) {
}

func buildTestContext(profile ProtectionProfile, opts ...ContextOption) (*Context, error) {
keyLen, err := profile.keyLen()
keyLen, err := profile.KeyLen()
if err != nil {
return nil, err
}
saltLen, err := profile.saltLen()
saltLen, err := profile.SaltLen()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -270,7 +270,7 @@ func rtpTestCases() []rtpTestCase {
func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) {
assert := assert.New(t)

authTagLen, err := profile.rtpAuthTagLen()
authTagLen, err := profile.AuthTagRTPLen()
assert.NoError(err)

for _, testCase := range rtpTestCases() {
Expand Down
8 changes: 4 additions & 4 deletions stream_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ func TestBufferFactory(t *testing.T) {
func benchmarkWrite(b *testing.B, profile ProtectionProfile, size int) {
conn := newNoopConn()

keyLen, err := profile.keyLen()
keyLen, err := profile.KeyLen()
if err != nil {
b.Fatal(err)
}
saltLen, err := profile.saltLen()
saltLen, err := profile.SaltLen()
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -147,11 +147,11 @@ func benchmarkWriteRTP(b *testing.B, profile ProtectionProfile, size int) {
closed: make(chan struct{}),
}

keyLen, err := profile.keyLen()
keyLen, err := profile.KeyLen()
if err != nil {
b.Fatal(err)
}
saltLen, err := profile.saltLen()
saltLen, err := profile.SaltLen()
if err != nil {
b.Fatal(err)
}
Expand Down

0 comments on commit e3e6d11

Please sign in to comment.