diff --git a/session.go b/session.go index 2e1f4fe..8fc1b9c 100644 --- a/session.go +++ b/session.go @@ -20,12 +20,17 @@ type streamSession interface { decrypt([]byte) error } +type newStream struct { + readStream readStream + payloadType uint8 +} + type session struct { localContextMutex sync.Mutex localContext, remoteContext *Context localOptions, remoteOptions []ContextOption - newStream chan readStream + newStream chan newStream acceptStreamTimeout time.Time started chan interface{} diff --git a/session_srtcp.go b/session_srtcp.go index 13f1a95..ae90536 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -53,7 +53,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, - newStream: make(chan readStream), + newStream: make(chan newStream), acceptStreamTimeout: config.AcceptStreamTimeout, started: make(chan interface{}), closed: make(chan interface{}), @@ -93,17 +93,17 @@ func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { // AcceptStream returns a stream to handle RTCP for a single SSRC func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { - stream, ok := <-s.newStream + newStream, ok := <-s.newStream if !ok { return nil, 0, errStreamAlreadyClosed } - readStream, ok := stream.(*ReadStreamSRTCP) + readStream, ok := newStream.readStream.(*ReadStreamSRTCP) if !ok { return nil, 0, errFailedTypeAssertion } - return readStream, stream.GetSSRC(), nil + return readStream, readStream.GetSSRC(), nil } // Close ends the session @@ -172,7 +172,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error { if !s.session.acceptStreamTimeout.IsZero() { _ = s.session.nextConn.SetReadDeadline(time.Time{}) } - s.session.newStream <- r // Notify AcceptStream + s.session.newStream <- newStream{readStream: r} // Notify AcceptStream } readStream, ok := r.(*ReadStreamSRTCP) diff --git a/session_srtp.go b/session_srtp.go index e07cbe2..e19b9c7 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -54,7 +54,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, - newStream: make(chan readStream), + newStream: make(chan newStream), acceptStreamTimeout: config.AcceptStreamTimeout, started: make(chan interface{}), closed: make(chan interface{}), @@ -93,19 +93,26 @@ func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { return nil, errFailedTypeAssertion } -// AcceptStream returns a stream to handle RTCP for a single SSRC +// AcceptStream returns a stream to handle RTP for a single SSRC func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { - stream, ok := <-s.newStream + readStream, ssrc, _, err := s.AcceptStreamWithPayloadType() + return readStream, ssrc, err +} + +// AcceptStreamWithPayloadType returns a stream to handle RTP for a single SSRC. +// It returns the payload type as well as the SSRC. +func (s *SessionSRTP) AcceptStreamWithPayloadType() (*ReadStreamSRTP, uint32, uint8, error) { + newStream, ok := <-s.newStream if !ok { - return nil, 0, errStreamAlreadyClosed + return nil, 0, 0, errStreamAlreadyClosed } - readStream, ok := stream.(*ReadStreamSRTP) + readStream, ok := newStream.readStream.(*ReadStreamSRTP) if !ok { - return nil, 0, errFailedTypeAssertion + return nil, 0, 0, errFailedTypeAssertion } - return readStream, stream.GetSSRC(), nil + return readStream, readStream.GetSSRC(), newStream.payloadType, nil } // Close ends the session @@ -178,7 +185,8 @@ func (s *SessionSRTP) decrypt(buf []byte) error { if !s.session.acceptStreamTimeout.IsZero() { _ = s.session.nextConn.SetReadDeadline(time.Time{}) } - s.session.newStream <- r // Notify AcceptStream + // notify AcceptStream + s.session.newStream <- newStream{readStream: r, payloadType: h.PayloadType} } readStream, ok := r.(*ReadStreamSRTP) diff --git a/stream_srtp.go b/stream_srtp.go index 8589700..3e3fcfa 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -7,6 +7,7 @@ import ( "errors" "io" "sync" + "sync/atomic" "time" "github.com/pion/rtp" @@ -27,6 +28,10 @@ type ReadStreamSRTP struct { isInited bool buffer io.ReadWriteCloser + + peekedPacket []byte + peekedPacketMu sync.Mutex + peekedPacketPresent atomic.Bool } // Used by getOrCreateReadStream @@ -74,8 +79,43 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { return n, err } +// Peek reads the next full RTP packet from the nextConn, but queues it internally. +// The next call to Read (or the next call to Peek without a call to Read in between) +// will return the same packet again. +func (r *ReadStreamSRTP) Peek(buf []byte) (int, error) { + r.peekedPacketMu.Lock() + defer r.peekedPacketMu.Unlock() + if r.peekedPacketPresent.Load() { + return copy(buf, r.peekedPacket), nil + } + n, err := r.buffer.Read(buf) + if err != nil { + return n, err + } + if cap(r.peekedPacket) < n { + size := 1500 + if size < n { + size = n + } + r.peekedPacket = make([]byte, size) + } + r.peekedPacket = r.peekedPacket[:n] + copy(r.peekedPacket, buf) + r.peekedPacketPresent.Store(true) + return n, nil +} + // Read reads and decrypts full RTP packet from the nextConn func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { + if r.peekedPacketPresent.Load() { + r.peekedPacketMu.Lock() + if r.peekedPacketPresent.Swap(false) { + n := copy(buf, r.peekedPacket) + r.peekedPacketMu.Unlock() + return n, nil + } + r.peekedPacketMu.Unlock() + } return r.buffer.Read(buf) }