From a732e10de996b6f845aa060c85c1dafd61b2e59b Mon Sep 17 00:00:00 2001 From: cptpcrd <31829097+cptpcrd@users.noreply.github.com> Date: Wed, 29 May 2024 15:10:49 -0400 Subject: [PATCH 1/3] Add new APIs to enable simulcast probes without dropping packets --- session.go | 7 ++++++- session_srtcp.go | 10 +++++----- session_srtp.go | 24 ++++++++++++++++-------- stream_srtp.go | 25 +++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 14 deletions(-) diff --git a/session.go b/session.go index 2e1f4fe..8fc1b9c 100644 --- a/session.go +++ b/session.go @@ -20,12 +20,17 @@ type streamSession interface { decrypt([]byte) error } +type newStream struct { + readStream readStream + payloadType uint8 +} + type session struct { localContextMutex sync.Mutex localContext, remoteContext *Context localOptions, remoteOptions []ContextOption - newStream chan readStream + newStream chan newStream acceptStreamTimeout time.Time started chan interface{} diff --git a/session_srtcp.go b/session_srtcp.go index 13f1a95..ae90536 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -53,7 +53,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, - newStream: make(chan readStream), + newStream: make(chan newStream), acceptStreamTimeout: config.AcceptStreamTimeout, started: make(chan interface{}), closed: make(chan interface{}), @@ -93,17 +93,17 @@ func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { // AcceptStream returns a stream to handle RTCP for a single SSRC func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { - stream, ok := <-s.newStream + newStream, ok := <-s.newStream if !ok { return nil, 0, errStreamAlreadyClosed } - readStream, ok := stream.(*ReadStreamSRTCP) + readStream, ok := newStream.readStream.(*ReadStreamSRTCP) if !ok { return nil, 0, errFailedTypeAssertion } - return readStream, stream.GetSSRC(), nil + return readStream, readStream.GetSSRC(), nil } // Close ends the session @@ -172,7 +172,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error { if !s.session.acceptStreamTimeout.IsZero() { _ = s.session.nextConn.SetReadDeadline(time.Time{}) } - s.session.newStream <- r // Notify AcceptStream + s.session.newStream <- newStream{readStream: r} // Notify AcceptStream } readStream, ok := r.(*ReadStreamSRTCP) diff --git a/session_srtp.go b/session_srtp.go index e07cbe2..e19b9c7 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -54,7 +54,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, - newStream: make(chan readStream), + newStream: make(chan newStream), acceptStreamTimeout: config.AcceptStreamTimeout, started: make(chan interface{}), closed: make(chan interface{}), @@ -93,19 +93,26 @@ func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { return nil, errFailedTypeAssertion } -// AcceptStream returns a stream to handle RTCP for a single SSRC +// AcceptStream returns a stream to handle RTP for a single SSRC func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { - stream, ok := <-s.newStream + readStream, ssrc, _, err := s.AcceptStreamWithPayloadType() + return readStream, ssrc, err +} + +// AcceptStreamWithPayloadType returns a stream to handle RTP for a single SSRC. +// It returns the payload type as well as the SSRC. +func (s *SessionSRTP) AcceptStreamWithPayloadType() (*ReadStreamSRTP, uint32, uint8, error) { + newStream, ok := <-s.newStream if !ok { - return nil, 0, errStreamAlreadyClosed + return nil, 0, 0, errStreamAlreadyClosed } - readStream, ok := stream.(*ReadStreamSRTP) + readStream, ok := newStream.readStream.(*ReadStreamSRTP) if !ok { - return nil, 0, errFailedTypeAssertion + return nil, 0, 0, errFailedTypeAssertion } - return readStream, stream.GetSSRC(), nil + return readStream, readStream.GetSSRC(), newStream.payloadType, nil } // Close ends the session @@ -178,7 +185,8 @@ func (s *SessionSRTP) decrypt(buf []byte) error { if !s.session.acceptStreamTimeout.IsZero() { _ = s.session.nextConn.SetReadDeadline(time.Time{}) } - s.session.newStream <- r // Notify AcceptStream + // notify AcceptStream + s.session.newStream <- newStream{readStream: r, payloadType: h.PayloadType} } readStream, ok := r.(*ReadStreamSRTP) diff --git a/stream_srtp.go b/stream_srtp.go index 8589700..861a45e 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -7,6 +7,7 @@ import ( "errors" "io" "sync" + "sync/atomic" "time" "github.com/pion/rtp" @@ -27,6 +28,9 @@ type ReadStreamSRTP struct { isInited bool buffer io.ReadWriteCloser + + peekedPacket atomic.Value + peekedPacketMu sync.Mutex } // Used by getOrCreateReadStream @@ -74,8 +78,29 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { return n, err } +// Peek reads the next full RTP packet from the nextConn, but queues it internally. +// The next call to Read (or the next call to Peek without a call to Read in between) +// will return the same packet again. +func (r *ReadStreamSRTP) Peek(buf []byte) (int, error) { + r.peekedPacketMu.Lock() + defer r.peekedPacketMu.Unlock() + if pkt, ok := r.peekedPacket.Swap((*[]byte)(nil)).(*[]byte); ok && pkt != nil { + return copy(buf, *pkt), nil + } + n, err := r.buffer.Read(buf) + if err == nil { + peekedPacket := make([]byte, n) + copy(peekedPacket, buf) + r.peekedPacket.Store(&peekedPacket) + } + return n, err +} + // Read reads and decrypts full RTP packet from the nextConn func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { + if pkt, ok := r.peekedPacket.Swap((*[]byte)(nil)).(*[]byte); ok && pkt != nil { + return copy(buf, *pkt), nil + } return r.buffer.Read(buf) } From d7484fbcffd2458ace26bffca60c2e90e1c34dc3 Mon Sep 17 00:00:00 2001 From: cptpcrd <31829097+cptpcrd@users.noreply.github.com> Date: Wed, 29 May 2024 16:27:11 -0400 Subject: [PATCH 2/3] Fix behavior after multiple calls to Peek --- stream_srtp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stream_srtp.go b/stream_srtp.go index 861a45e..5d4e507 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -84,7 +84,7 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { func (r *ReadStreamSRTP) Peek(buf []byte) (int, error) { r.peekedPacketMu.Lock() defer r.peekedPacketMu.Unlock() - if pkt, ok := r.peekedPacket.Swap((*[]byte)(nil)).(*[]byte); ok && pkt != nil { + if pkt, ok := r.peekedPacket.Load().(*[]byte); ok && pkt != nil { return copy(buf, *pkt), nil } n, err := r.buffer.Read(buf) From f09457d39b5bba39ca794318c070997b63a937ce Mon Sep 17 00:00:00 2001 From: cptpcrd <31829097+cptpcrd@users.noreply.github.com> Date: Wed, 29 May 2024 16:56:04 -0400 Subject: [PATCH 3/3] Avoid allocating on every call to Peek --- stream_srtp.go | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/stream_srtp.go b/stream_srtp.go index 5d4e507..3e3fcfa 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -29,8 +29,9 @@ type ReadStreamSRTP struct { buffer io.ReadWriteCloser - peekedPacket atomic.Value - peekedPacketMu sync.Mutex + peekedPacket []byte + peekedPacketMu sync.Mutex + peekedPacketPresent atomic.Bool } // Used by getOrCreateReadStream @@ -84,22 +85,36 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { func (r *ReadStreamSRTP) Peek(buf []byte) (int, error) { r.peekedPacketMu.Lock() defer r.peekedPacketMu.Unlock() - if pkt, ok := r.peekedPacket.Load().(*[]byte); ok && pkt != nil { - return copy(buf, *pkt), nil + if r.peekedPacketPresent.Load() { + return copy(buf, r.peekedPacket), nil } n, err := r.buffer.Read(buf) - if err == nil { - peekedPacket := make([]byte, n) - copy(peekedPacket, buf) - r.peekedPacket.Store(&peekedPacket) + if err != nil { + return n, err } - return n, err + if cap(r.peekedPacket) < n { + size := 1500 + if size < n { + size = n + } + r.peekedPacket = make([]byte, size) + } + r.peekedPacket = r.peekedPacket[:n] + copy(r.peekedPacket, buf) + r.peekedPacketPresent.Store(true) + return n, nil } // Read reads and decrypts full RTP packet from the nextConn func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { - if pkt, ok := r.peekedPacket.Swap((*[]byte)(nil)).(*[]byte); ok && pkt != nil { - return copy(buf, *pkt), nil + if r.peekedPacketPresent.Load() { + r.peekedPacketMu.Lock() + if r.peekedPacketPresent.Swap(false) { + n := copy(buf, r.peekedPacket) + r.peekedPacketMu.Unlock() + return n, nil + } + r.peekedPacketMu.Unlock() } return r.buffer.Read(buf) }