From dbc471e9cabcb699f866018373618cde5b6e393e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Fru=C5=BCy=C5=84ski?= Date: Sat, 24 Feb 2024 22:02:06 +0100 Subject: [PATCH] Make ProtectionProfile methods public This allows to get values like key length and use them with key management protocols like MIKEY, instead of creating own constants for this. Resolves #258 --- context.go | 4 ++-- key_derivation_test.go | 2 +- keying.go | 4 ++-- protection_profile.go | 19 +++++++++++++------ protection_profile_test.go | 4 ++-- session_srtcp_test.go | 2 +- srtcp.go | 4 ++-- srtcp_test.go | 10 +++++----- srtp.go | 2 +- srtp_cipher.go | 10 +++++----- srtp_cipher_aead_aes_gcm.go | 8 ++++---- srtp_cipher_aes_cm_hmac_sha1.go | 14 +++++++------- srtp_cipher_test.go | 4 ++-- srtp_test.go | 10 +++++----- stream_srtp_test.go | 8 ++++---- 15 files changed, 56 insertions(+), 49 deletions(-) diff --git a/context.go b/context.go index 27da02c..0d0ce3e 100644 --- a/context.go +++ b/context.go @@ -66,12 +66,12 @@ type Context struct { // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { - keyLen, err := profile.keyLen() + keyLen, err := profile.KeyLen() if err != nil { return nil, err } - saltLen, err := profile.saltLen() + saltLen, err := profile.SaltLen() if err != nil { return nil, err } diff --git a/key_derivation_test.go b/key_derivation_test.go index 9995e32..0b7cf18 100644 --- a/key_derivation_test.go +++ b/key_derivation_test.go @@ -32,7 +32,7 @@ func TestValidSessionKeys(t *testing.T) { t.Errorf("Session Salt % 02x does not match expected % 02x", sessionSalt, expectedSessionSalt) } - authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen() + authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.AuthKeyLen() assert.NoError(t, err) sessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) diff --git a/keying.go b/keying.go index c5977c3..617f4d7 100644 --- a/keying.go +++ b/keying.go @@ -14,12 +14,12 @@ type KeyingMaterialExporter interface { // extracting them from DTLS. This behavior is defined in RFC5764: // https://tools.ietf.org/html/rfc5764 func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isClient bool) error { - keyLen, err := c.Profile.keyLen() + keyLen, err := c.Profile.KeyLen() if err != nil { return err } - saltLen, err := c.Profile.saltLen() + saltLen, err := c.Profile.SaltLen() if err != nil { return err } diff --git a/protection_profile.go b/protection_profile.go index 9f3a2d4..6271407 100644 --- a/protection_profile.go +++ b/protection_profile.go @@ -17,7 +17,8 @@ const ( ProtectionProfileAeadAes256Gcm ProtectionProfile = 0x0008 ) -func (p ProtectionProfile) keyLen() (int, error) { +// KeyLen returns length of encryption key in bytes. +func (p ProtectionProfile) KeyLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm: return 16, nil @@ -28,7 +29,8 @@ func (p ProtectionProfile) keyLen() (int, error) { } } -func (p ProtectionProfile) saltLen() (int, error) { +// SaltLen returns length of salt key in bytes. +func (p ProtectionProfile) SaltLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 14, nil @@ -39,7 +41,8 @@ func (p ProtectionProfile) saltLen() (int, error) { } } -func (p ProtectionProfile) rtpAuthTagLen() (int, error) { +// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero. +func (p ProtectionProfile) AuthTagRTPLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: return 10, nil @@ -52,7 +55,8 @@ func (p ProtectionProfile) rtpAuthTagLen() (int, error) { } } -func (p ProtectionProfile) rtcpAuthTagLen() (int, error) { +// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero. +func (p ProtectionProfile) AuthTagRTCPLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 10, nil @@ -63,7 +67,8 @@ func (p ProtectionProfile) rtcpAuthTagLen() (int, error) { } } -func (p ProtectionProfile) aeadAuthTagLen() (int, error) { +// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. For AES ones it returns zero. +func (p ProtectionProfile) AEADAuthTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 0, nil @@ -74,7 +79,8 @@ func (p ProtectionProfile) aeadAuthTagLen() (int, error) { } } -func (p ProtectionProfile) authKeyLen() (int, error) { +// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. For AEAD ones it returns zero. +func (p ProtectionProfile) AuthKeyLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 20, nil @@ -85,6 +91,7 @@ func (p ProtectionProfile) authKeyLen() (int, error) { } } +// String returns the name of the protection profile. func (p ProtectionProfile) String() string { switch p { case ProtectionProfileAes128CmHmacSha1_80: diff --git a/protection_profile_test.go b/protection_profile_test.go index b0edfd5..077e343 100644 --- a/protection_profile_test.go +++ b/protection_profile_test.go @@ -12,9 +12,9 @@ import ( func TestInvalidProtectionProfile(t *testing.T) { var invalidProtectionProfile ProtectionProfile - _, err := invalidProtectionProfile.keyLen() + _, err := invalidProtectionProfile.KeyLen() assert.Error(t, err) - _, err = invalidProtectionProfile.saltLen() + _, err = invalidProtectionProfile.SaltLen() assert.Error(t, err) } diff --git a/session_srtcp_test.go b/session_srtcp_test.go index d5dc8e6..f2f8ed6 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -341,7 +341,7 @@ func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) { } func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { - authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.rtcpAuthTagLen() + authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.AuthTagRTCPLen() if err != nil { return 0, err } diff --git a/srtcp.go b/srtcp.go index 0aae9ab..86a963f 100644 --- a/srtcp.go +++ b/srtcp.go @@ -15,11 +15,11 @@ const maxSRTCPIndex = 0x7FFFFFFF func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { out := allocateIfMismatch(dst, encrypted) - authTagLen, err := c.cipher.rtcpAuthTagLen() + authTagLen, err := c.cipher.AuthTagRTCPLen() if err != nil { return nil, err } - aeadAuthTagLen, err := c.cipher.aeadAuthTagLen() + aeadAuthTagLen, err := c.cipher.AEADAuthTagLen() if err != nil { return nil, err } diff --git a/srtcp_test.go b/srtcp_test.go index a089bfe..3018757 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -164,10 +164,10 @@ func TestRTCPLifecycleInPlace(t *testing.T) { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) - authTagLen, err := testCase.algo.rtcpAuthTagLen() + authTagLen, err := testCase.algo.AuthTagRTCPLen() assert.NoError(err) - aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() + aeadAuthTagLen, err := testCase.algo.AEADAuthTagLen() assert.NoError(err) encryptHeader := &rtcp.Header{} @@ -272,10 +272,10 @@ func TestRTCPInvalidAuthTag(t *testing.T) { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) - authTagLen, err := testCase.algo.rtcpAuthTagLen() + authTagLen, err := testCase.algo.AuthTagRTCPLen() assert.NoError(err) - aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() + aeadAuthTagLen, err := testCase.algo.AEADAuthTagLen() assert.NoError(err) decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) @@ -354,7 +354,7 @@ func TestEncryptRTCPSeparation(t *testing.T) { encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) assert.NoError(err) - authTagLen, err := testCase.algo.rtcpAuthTagLen() + authTagLen, err := testCase.algo.AuthTagRTCPLen() assert.NoError(err) decryptContext, err := CreateContext( diff --git a/srtp.go b/srtp.go index 3a7ed11..e02ceff 100644 --- a/srtp.go +++ b/srtp.go @@ -9,7 +9,7 @@ import ( ) func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { - authTagLen, err := c.cipher.rtpAuthTagLen() + authTagLen, err := c.cipher.AuthTagRTPLen() if err != nil { return nil, err } diff --git a/srtp_cipher.go b/srtp_cipher.go index db50147..1af88df 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -8,13 +8,13 @@ import "github.com/pion/rtp" // cipher represents a implementation of one // of the SRTP Specific ciphers type srtpCipher interface { - // authTagLen returns auth key length of the cipher. + // AuthTagRTPLen/AuthTagRTCPLen return auth key length of the cipher. // See the note below. - rtpAuthTagLen() (int, error) - rtcpAuthTagLen() (int, error) - // aeadAuthTagLen returns AEAD auth key length of the cipher. + AuthTagRTPLen() (int, error) + AuthTagRTCPLen() (int, error) + // AEADAuthTagLen returns AEAD auth key length of the cipher. // See the note below. - aeadAuthTagLen() (int, error) + AEADAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 90643d9..744fbbb 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -67,7 +67,7 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt [] func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. - authTagLen, err := s.aeadAuthTagLen() + authTagLen, err := s.AEADAuthTagLen() if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Grow the given buffer to fit the output. - authTagLen, err := s.aeadAuthTagLen() + authTagLen, err := s.AEADAuthTagLen() if err != nil { return nil, err } @@ -109,7 +109,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He } func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { - authTagLen, err := s.aeadAuthTagLen() + authTagLen, err := s.AEADAuthTagLen() if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { aadPos := len(encrypted) - srtcpIndexSize // Grow the given buffer to fit the output. - authTagLen, err := s.aeadAuthTagLen() + authTagLen, err := s.AEADAuthTagLen() if err != nil { return nil, err } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index 20e0376..11369af 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -49,7 +49,7 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt return nil, err } - authKeyLen, err := profile.authKeyLen() + authKeyLen, err := profile.AuthKeyLen() if err != nil { return nil, err } @@ -71,7 +71,7 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. - authTagLen, err := s.rtpAuthTagLen() + authTagLen, err := s.AuthTagRTPLen() if err != nil { return nil, err } @@ -104,7 +104,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Split the auth tag and the cipher text into two parts. - authTagLen, err := s.rtpAuthTagLen() + authTagLen, err := s.AuthTagRTPLen() if err != nil { return nil, err } @@ -156,7 +156,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex } func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc uint32) ([]byte, error) { - authTagLen, err := s.rtcpAuthTagLen() + authTagLen, err := s.AuthTagRTCPLen() if err != nil { return nil, err } @@ -213,7 +213,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([ } // Truncate the hash to the size indicated by the profile - authTagLen, err := s.rtpAuthTagLen() + authTagLen, err := s.AuthTagRTPLen() if err != nil { return nil, err } @@ -237,7 +237,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro if _, err := s.srtcpSessionAuth.Write(buf); err != nil { return nil, err } - authTagLen, err := s.rtcpAuthTagLen() + authTagLen, err := s.AuthTagRTCPLen() if err != nil { return nil, err } @@ -246,7 +246,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro } func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { - authTagLen, _ := s.rtcpAuthTagLen() + authTagLen, _ := s.AuthTagRTCPLen() tailOffset := len(in) - (authTagLen + srtcpIndexSize) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) diff --git a/srtp_cipher_test.go b/srtp_cipher_test.go index d45907e..2066c94 100644 --- a/srtp_cipher_test.go +++ b/srtp_cipher_test.go @@ -119,11 +119,11 @@ func createTestCiphers() []testCipher { } for k, v := range tests { - keyLen, err := v.profile.keyLen() + keyLen, err := v.profile.KeyLen() if err != nil { panic(err) } - saltLen, err := v.profile.saltLen() + saltLen, err := v.profile.SaltLen() if err != nil { panic(err) } diff --git a/srtp_test.go b/srtp_test.go index ef9bee2..34c483c 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -37,10 +37,10 @@ func (tc rtpTestCase) encrypted(profile ProtectionProfile) []byte { } func testKeyLen(t *testing.T, profile ProtectionProfile) { - keyLen, err := profile.keyLen() + keyLen, err := profile.KeyLen() assert.NoError(t, err) - saltLen, err := profile.saltLen() + saltLen, err := profile.SaltLen() assert.NoError(t, err) if _, err := CreateContext([]byte{}, make([]byte, saltLen), profile); err == nil { @@ -181,11 +181,11 @@ func TestRolloverCountOverflow(t *testing.T) { } func buildTestContext(profile ProtectionProfile, opts ...ContextOption) (*Context, error) { - keyLen, err := profile.keyLen() + keyLen, err := profile.KeyLen() if err != nil { return nil, err } - saltLen, err := profile.saltLen() + saltLen, err := profile.SaltLen() if err != nil { return nil, err } @@ -270,7 +270,7 @@ func rtpTestCases() []rtpTestCase { func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) - authTagLen, err := profile.rtpAuthTagLen() + authTagLen, err := profile.AuthTagRTPLen() assert.NoError(err) for _, testCase := range rtpTestCases() { diff --git a/stream_srtp_test.go b/stream_srtp_test.go index 755c1c4..38a89b8 100644 --- a/stream_srtp_test.go +++ b/stream_srtp_test.go @@ -67,11 +67,11 @@ func TestBufferFactory(t *testing.T) { func benchmarkWrite(b *testing.B, profile ProtectionProfile, size int) { conn := newNoopConn() - keyLen, err := profile.keyLen() + keyLen, err := profile.KeyLen() if err != nil { b.Fatal(err) } - saltLen, err := profile.saltLen() + saltLen, err := profile.SaltLen() if err != nil { b.Fatal(err) } @@ -147,11 +147,11 @@ func benchmarkWriteRTP(b *testing.B, profile ProtectionProfile, size int) { closed: make(chan struct{}), } - keyLen, err := profile.keyLen() + keyLen, err := profile.KeyLen() if err != nil { b.Fatal(err) } - saltLen, err := profile.saltLen() + saltLen, err := profile.SaltLen() if err != nil { b.Fatal(err) }