Skip to content

Commit

Permalink
refactor: define a package for message sink
Browse files Browse the repository at this point in the history
This allows to generalize the message sink and get rid of a lot of
copy-paste in the handling functions.

Also this moves types to the right modules, so that `peer` is now
completely matrix-unaware module that contains only plain WebRTC logic.
  • Loading branch information
daniel-abramov committed Nov 24, 2022
1 parent 25ba9e2 commit 6b8eee8
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 261 deletions.
38 changes: 38 additions & 0 deletions src/common/message_sink.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package common

// MessageSink is a helper struct that allows to send messages to a message sink.
// The MessageSink abstracts the message sink which has a certain sender, so that
// the sender does not have to be specified every time a message is sent.
// At the same it guarantees that the caller can't alter the `sender`, which means that
// the sender can't impersonate another sender (and we guarantee this on a compile-time).
type MessageSink[SenderType comparable, MessageType any] struct {
// The sender of the messages. This is useful for multiple-producer-single-consumer scenarios.
sender SenderType
// The message sink to which the messages are sent.
messageSink chan<- Message[SenderType, MessageType]
}

// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases.
func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] {
return &MessageSink[S, M]{
sender: sender,
messageSink: messageSink,
}
}

// Sends a message to the message sink.
func (s *MessageSink[S, M]) Send(message M) {
s.messageSink <- Message[S, M]{
Sender: s.sender,
Content: message,
}
}

// Messages that are sent from the peer to the conference in order to communicate with other peers.
// Since each peer is isolated from others, it can't influence the state of other peers directly.
type Message[SenderType comparable, MessageType any] struct {
// The sender of the message.
Sender SenderType
// The content of the message.
Content MessageType
}
58 changes: 34 additions & 24 deletions src/conference/conference.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package conference

import (
"github.com/matrix-org/waterfall/src/common"
"github.com/matrix-org/waterfall/src/peer"
"github.com/matrix-org/waterfall/src/signaling"
"github.com/pion/webrtc/v3"
Expand All @@ -25,22 +26,22 @@ import (
)

type Conference struct {
id string
config Config
signaling signaling.MatrixSignaling
participants map[peer.ID]*Participant
peerEventsStream chan peer.Message
logger *logrus.Entry
id string
config Config
signaling signaling.MatrixSignaling
participants map[ParticipantID]*Participant
peerEvents chan common.Message[ParticipantID, peer.MessageContent]
logger *logrus.Entry
}

func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference {
conference := &Conference{
id: confID,
config: config,
signaling: signaling,
participants: make(map[peer.ID]*Participant),
peerEventsStream: make(chan peer.Message),
logger: logrus.WithFields(logrus.Fields{"conf_id": confID}),
id: confID,
config: config,
signaling: signaling,
participants: make(map[ParticipantID]*Participant),
peerEvents: make(chan common.Message[ParticipantID, peer.MessageContent]),
logger: logrus.WithFields(logrus.Fields{"conf_id": confID}),
}

// Start conference "main loop".
Expand All @@ -49,7 +50,7 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna
}

// New participant tries to join the conference.
func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.CallInviteEventContent) {
func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) {
// As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event,
// any existing calls from this device in this call should be terminated.
// TODO: Implement this.
Expand All @@ -67,7 +68,16 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.
}
}

peer, sdpOffer, err := peer.NewPeer(participantID, c.id, inviteEvent.Offer.SDP, c.peerEventsStream)
var (
participantlogger = logrus.WithFields(logrus.Fields{
"user_id": participantID.UserID,
"device_id": participantID.DeviceID,
"conf_id": c.id,
})
messageSink = common.NewMessageSink(participantID, c.peerEvents)
)

peer, sdpOffer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, participantlogger)
if err != nil {
c.logger.WithError(err).Errorf("Failed to create new peer")
return
Expand All @@ -88,11 +98,11 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.
c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP)
}

func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) {
if participant := c.getParticipant(peerID, nil); participant != nil {
func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) {
if participant := c.getParticipant(participantID, nil); participant != nil {
// Convert the candidates to the WebRTC format.
candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates))
for i, candidate := range candidatesEvent.Candidates {
candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates))
for i, candidate := range ev.Candidates {
SDPMLineIndex := uint16(candidate.SDPMLineIndex)
candidates[i] = webrtc.ICECandidateInit{
Candidate: candidate.Candidate,
Expand All @@ -105,19 +115,19 @@ func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCan
}
}

func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) {
if participant := c.getParticipant(peerID, nil); participant != nil {
if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() {
func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) {
if participant := c.getParticipant(participantID, nil); participant != nil {
if ev.SelectedPartyID != participantID.DeviceID.String() {
c.logger.WithFields(logrus.Fields{
"device_id": selectAnswerEvent.SelectedPartyID,
"device_id": ev.SelectedPartyID,
}).Errorf("Call was answered on a different device, kicking this peer")
participant.peer.Terminate()
}
}
}

func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) {
if participant := c.getParticipant(peerID, nil); participant != nil {
func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) {
if participant := c.getParticipant(participantID, nil); participant != nil {
participant.peer.Terminate()
}
}
22 changes: 14 additions & 8 deletions src/conference/participant.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@ import (

var ErrInvalidSFUMessage = errors.New("invalid SFU message")

type ParticipantID struct {
UserID id.UserID
DeviceID id.DeviceID
}

type Participant struct {
id peer.ID
peer *peer.Peer
id ParticipantID
peer *peer.Peer[ParticipantID]
remoteSessionID id.SessionID
streamMetadata event.CallSDPStreamMetadata
publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
}

func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {
return signaling.MatrixRecipient{
ID: p.id,
UserID: p.id.UserID,
DeviceID: p.id.DeviceID,
RemoteSessionID: p.remoteSessionID,
}
}
Expand All @@ -44,12 +50,12 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error {
return nil
}

func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) *Participant {
participant, ok := c.participants[peerID]
func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant {
participant, ok := c.participants[participantID]
if !ok {
logEntry := c.logger.WithFields(logrus.Fields{
"user_id": peerID.UserID,
"device_id": peerID.DeviceID,
"user_id": participantID.UserID,
"device_id": participantID.DeviceID,
})

if optionalErrorMessage != nil {
Expand All @@ -64,7 +70,7 @@ func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error)
return participant
}

func (c *Conference) getStreamsMetadata(forParticipant peer.ID) event.CallSDPStreamMetadata {
func (c *Conference) getStreamsMetadata(forParticipant ParticipantID) event.CallSDPStreamMetadata {
streamsMetadata := make(event.CallSDPStreamMetadata)
for id, participant := range c.participants {
if forParticipant != id {
Expand Down
50 changes: 10 additions & 40 deletions src/conference/messages.go → src/conference/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,35 @@ import (
"encoding/json"
"errors"

"github.com/matrix-org/waterfall/src/common"
"github.com/matrix-org/waterfall/src/peer"
"maunium.net/go/mautrix/event"
)

func (c *Conference) processMessages() {
for {
// Read a message from the stream (of type peer.Message) and process it.
message := <-c.peerEventsStream
message := <-c.peerEvents
c.processPeerMessage(message)
}
}

//nolint:funlen
func (c *Conference) processPeerMessage(message peer.Message) {
func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
if participant == nil {
return
}

// Since Go does not support ADTs, we have to use a switch statement to
// determine the actual type of the message.
switch msg := message.(type) {
switch msg := message.Content.(type) {
case peer.JoinedTheCall:
case peer.LeftTheCall:
delete(c.participants, msg.Sender)
delete(c.participants, message.Sender)
// TODO: Send new metadata about available streams to all participants.
// TODO: Send the hangup event over the Matrix back to the user.

case peer.NewTrackPublished:
participant := c.getParticipant(msg.Sender, errors.New("New track published from unknown participant"))
if participant == nil {
return
}

key := event.SFUTrackDescription{
StreamID: msg.Track.StreamID(),
TrackID: msg.Track.ID(),
Expand All @@ -46,11 +46,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
participant.publishedTracks[key] = msg.Track

case peer.PublishedTrackFailed:
participant := c.getParticipant(msg.Sender, errors.New("Published track failed from unknown participant"))
if participant == nil {
return
}

delete(participant.publishedTracks, event.SFUTrackDescription{
StreamID: msg.Track.StreamID(),
TrackID: msg.Track.ID(),
Expand All @@ -59,11 +54,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
// TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically?

case peer.NewICECandidate:
participant := c.getParticipant(msg.Sender, errors.New("ICE candidate from unknown participant"))
if participant == nil {
return
}

// Convert WebRTC ICE candidate to Matrix ICE candidate.
jsonCandidate := msg.Candidate.ToJSON()
candidates := []event.CallCandidate{{
Expand All @@ -74,20 +64,10 @@ func (c *Conference) processPeerMessage(message peer.Message) {
c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)

case peer.ICEGatheringComplete:
participant := c.getParticipant(msg.Sender, errors.New("Received ICE complete from unknown participant"))
if participant == nil {
return
}

// Send an empty array of candidates to indicate that ICE gathering is complete.
c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())

case peer.RenegotiationRequired:
participant := c.getParticipant(msg.Sender, errors.New("Renegotiation from unknown participant"))
if participant == nil {
return
}

toSend := event.SFUMessage{
Op: event.SFUOperationOffer,
SDP: msg.Offer.SDP,
Expand All @@ -97,11 +77,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
participant.sendDataChannelMessage(toSend)

case peer.DataChannelMessage:
participant := c.getParticipant(msg.Sender, errors.New("Data channel message from unknown participant"))
if participant == nil {
return
}

var sfuMessage event.SFUMessage
if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
Expand All @@ -111,11 +86,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
c.handleDataChannelMessage(participant, sfuMessage)

case peer.DataChannelAvailable:
participant := c.getParticipant(msg.Sender, errors.New("Data channel available from unknown participant"))
if participant == nil {
return
}

toSend := event.SFUMessage{
Op: event.SFUOperationMetadata,
Metadata: c.getStreamsMetadata(participant.id),
Expand Down
48 changes: 0 additions & 48 deletions src/peer/channel.go

This file was deleted.

8 changes: 0 additions & 8 deletions src/peer/id.go

This file was deleted.

Loading

0 comments on commit 6b8eee8

Please sign in to comment.