diff --git a/pkg/conference/publisher/publisher.go b/pkg/conference/publisher/publisher.go index efad3ad..7e7789e 100644 --- a/pkg/conference/publisher/publisher.go +++ b/pkg/conference/publisher/publisher.go @@ -4,6 +4,7 @@ import ( "errors" "io" "sync" + "time" "github.com/pion/rtp" "github.com/sirupsen/logrus" @@ -30,31 +31,37 @@ type Publisher struct { subscriptions map[Subscription]struct{} } +// Starts a new publisher, returns a publisher along with the channel that informs the caller +// about the status update of the publisher (i.e. stalled, or active). Once the channel is closed, +// the publisher can be considered stopped. func NewPublisher( track Track, stop <-chan struct{}, log *logrus.Entry, -) (*Publisher, <-chan struct{}) { - // Create a done channel, so that we can signal the caller when we're done. - done := make(chan struct{}) - +) (*Publisher, <-chan Status) { publisher := &Publisher{ logger: log, track: track, subscriptions: make(map[Subscription]struct{}), } + // Start an observer that expects us to inform it every time we receive a packet. + // When no packets are received for N seconds, the observer will report the stalled status. + observer := newStatusObserver(2 * time.Second) + // Start a goroutine that will read RTP packets from the remote track. // We run the publisher until we receive a stop signal or an error occurs. go func() { - defer close(done) + defer observer.stop() + reportFrameReceived := func() { observer.packetArrived() } + for { // Check if we were signaled to stop. select { case <-stop: return default: - if err := publisher.forwardPacket(); err != nil { + if err := publisher.forwardPacket(reportFrameReceived); err != nil { logStoppedFn := log.Infof if err != io.EOF { logStoppedFn = log.Errorf @@ -67,7 +74,7 @@ func NewPublisher( } }() - return publisher, done + return publisher, observer.statusCh } func (p *Publisher) AddSubscription(subscription Subscription) { @@ -100,7 +107,9 @@ func (p *Publisher) ReplaceTrack(track Track) { } // Reads a single packet from the remote track and forwards it to all subscribers. -func (p *Publisher) forwardPacket() error { +// The function stops when the remote track is closed or an error occurs when reading. +// Each time new packet is received, the provided callback is called. +func (p *Publisher) forwardPacket(reportFrameReceived func()) error { track := p.GetTrack() packet, err := track.ReadPacket() @@ -108,6 +117,9 @@ func (p *Publisher) forwardPacket() error { return err } + // Inform the observer that we received a packet. + reportFrameReceived() + p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/conference/publisher/status.go b/pkg/conference/publisher/status.go new file mode 100644 index 0000000..07da1e5 --- /dev/null +++ b/pkg/conference/publisher/status.go @@ -0,0 +1,56 @@ +package publisher + +import ( + "sync/atomic" + "time" + + "github.com/matrix-org/waterfall/pkg/worker" +) + +type Status int + +const ( + StatusStalled Status = iota + 1 + StatusRecovered +) + +// `statusObserver` is a helper that observes the status of the publisher. +// Essentially it's a simple worker that expects to be informed about new packet +// arrivals. If no packets are received for N seconds, the worker will report the +// stalled status over the `statusCh` channel. Likewise, it'll update the status to +// recovered once a new packet is received. +type statusObserver struct { + worker *worker.Worker[struct{}] + statusCh chan Status + stalled atomic.Bool +} + +func newStatusObserver(timeout time.Duration) *statusObserver { + statusCh := make(chan Status, 1) + stalled := atomic.Bool{} + + worker := worker.StartWorker(worker.Config[struct{}]{ + ChannelSize: 1, + Timeout: timeout, + OnTimeout: func() { + stalled.Store(true) + statusCh <- StatusStalled + }, + OnTask: func(struct{}) { + if stalled.CompareAndSwap(true, false) { + statusCh <- StatusRecovered + } + }, + }) + + return &statusObserver{worker, statusCh, stalled} +} + +func (o *statusObserver) packetArrived() { + o.worker.Send(struct{}{}) +} + +func (o *statusObserver) stop() { + o.worker.Stop() + close(o.statusCh) +} diff --git a/pkg/conference/track/internal.go b/pkg/conference/track/internal.go index 4866ab6..2912028 100644 --- a/pkg/conference/track/internal.go +++ b/pkg/conference/track/internal.go @@ -50,7 +50,7 @@ func forward(sender *webrtc.TrackRemote, receiver *webrtc.TrackLocalStaticRTP, s func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemote) { simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) - pub, done := publisher.NewPublisher( + pub, statusCh := publisher.NewPublisher( &publisher.RemoteTrack{track}, p.stopPublishers, p.logger.WithField("layer", simulcast), @@ -65,7 +65,14 @@ func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemo go func() { defer p.activePublishers.Done() defer p.telemetry.AddEvent("video publisher stopped", attribute.String("simulcast", simulcast.String())) - <-done + + // Wait for the channel to be closed. Ignore statuses for now. + for { + _, closed := <-statusCh + if closed { + break + } + } p.mutex.Lock() defer p.mutex.Unlock()