diff --git a/src/common/message_sink.go b/src/common/message_sink.go new file mode 100644 index 0000000..3351492 --- /dev/null +++ b/src/common/message_sink.go @@ -0,0 +1,38 @@ +package common + +// MessageSink is a helper struct that allows to send messages to a message sink. +// The MessageSink abstracts the message sink which has a certain sender, so that +// the sender does not have to be specified every time a message is sent. +// At the same it guarantees that the caller can't alter the `sender`, which means that +// the sender can't impersonate another sender (and we guarantee this on a compile-time). +type MessageSink[SenderType comparable, MessageType any] struct { + // The sender of the messages. This is useful for multiple-producer-single-consumer scenarios. + sender SenderType + // The message sink to which the messages are sent. + messageSink chan<- Message[SenderType, MessageType] +} + +// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. +func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] { + return &MessageSink[S, M]{ + sender: sender, + messageSink: messageSink, + } +} + +// Sends a message to the message sink. +func (s *MessageSink[S, M]) Send(message M) { + s.messageSink <- Message[S, M]{ + Sender: s.sender, + Content: message, + } +} + +// Messages that are sent from the peer to the conference in order to communicate with other peers. +// Since each peer is isolated from others, it can't influence the state of other peers directly. +type Message[SenderType comparable, MessageType any] struct { + // The sender of the message. + Sender SenderType + // The content of the message. + Content MessageType +} diff --git a/src/conference/conference.go b/src/conference/conference.go index 1a2a243..56ef684 100644 --- a/src/conference/conference.go +++ b/src/conference/conference.go @@ -17,6 +17,7 @@ limitations under the License. package conference import ( + "github.com/matrix-org/waterfall/src/common" "github.com/matrix-org/waterfall/src/peer" "github.com/matrix-org/waterfall/src/signaling" "github.com/pion/webrtc/v3" @@ -25,22 +26,22 @@ import ( ) type Conference struct { - id string - config Config - signaling signaling.MatrixSignaling - participants map[peer.ID]*Participant - peerEventsStream chan peer.Message - logger *logrus.Entry + id string + config Config + signaling signaling.MatrixSignaling + participants map[ParticipantID]*Participant + peerEvents chan common.Message[ParticipantID, peer.MessageContent] + logger *logrus.Entry } func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference { conference := &Conference{ - id: confID, - config: config, - signaling: signaling, - participants: make(map[peer.ID]*Participant), - peerEventsStream: make(chan peer.Message), - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + id: confID, + config: config, + signaling: signaling, + participants: make(map[ParticipantID]*Participant), + peerEvents: make(chan common.Message[ParticipantID, peer.MessageContent]), + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } // Start conference "main loop". @@ -49,7 +50,7 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna } // New participant tries to join the conference. -func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.CallInviteEventContent) { +func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) { // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. // TODO: Implement this. @@ -67,7 +68,16 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event. } } - peer, sdpOffer, err := peer.NewPeer(participantID, c.id, inviteEvent.Offer.SDP, c.peerEventsStream) + var ( + participantlogger = logrus.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + "conf_id": c.id, + }) + messageSink = common.NewMessageSink(participantID, c.peerEvents) + ) + + peer, sdpOffer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, participantlogger) if err != nil { c.logger.WithError(err).Errorf("Failed to create new peer") return @@ -88,11 +98,11 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event. c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP) } -func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { +func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { // Convert the candidates to the WebRTC format. - candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates)) - for i, candidate := range candidatesEvent.Candidates { + candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) + for i, candidate := range ev.Candidates { SDPMLineIndex := uint16(candidate.SDPMLineIndex) candidates[i] = webrtc.ICECandidateInit{ Candidate: candidate.Candidate, @@ -105,19 +115,19 @@ func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCan } } -func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { - if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() { +func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + if ev.SelectedPartyID != participantID.DeviceID.String() { c.logger.WithFields(logrus.Fields{ - "device_id": selectAnswerEvent.SelectedPartyID, + "device_id": ev.SelectedPartyID, }).Errorf("Call was answered on a different device, kicking this peer") participant.peer.Terminate() } } } -func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { +func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { participant.peer.Terminate() } } diff --git a/src/conference/participant.go b/src/conference/participant.go index fc0ced3..08b5f88 100644 --- a/src/conference/participant.go +++ b/src/conference/participant.go @@ -14,9 +14,14 @@ import ( var ErrInvalidSFUMessage = errors.New("invalid SFU message") +type ParticipantID struct { + UserID id.UserID + DeviceID id.DeviceID +} + type Participant struct { - id peer.ID - peer *peer.Peer + id ParticipantID + peer *peer.Peer[ParticipantID] remoteSessionID id.SessionID streamMetadata event.CallSDPStreamMetadata publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP @@ -24,7 +29,8 @@ type Participant struct { func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { return signaling.MatrixRecipient{ - ID: p.id, + UserID: p.id.UserID, + DeviceID: p.id.DeviceID, RemoteSessionID: p.remoteSessionID, } } @@ -44,12 +50,12 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error { return nil } -func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) *Participant { - participant, ok := c.participants[peerID] +func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[participantID] if !ok { logEntry := c.logger.WithFields(logrus.Fields{ - "user_id": peerID.UserID, - "device_id": peerID.DeviceID, + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, }) if optionalErrorMessage != nil { @@ -64,7 +70,7 @@ func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) return participant } -func (c *Conference) getStreamsMetadata(forParticipant peer.ID) event.CallSDPStreamMetadata { +func (c *Conference) getStreamsMetadata(forParticipant ParticipantID) event.CallSDPStreamMetadata { streamsMetadata := make(event.CallSDPStreamMetadata) for id, participant := range c.participants { if forParticipant != id { diff --git a/src/conference/messages.go b/src/conference/processor.go similarity index 77% rename from src/conference/messages.go rename to src/conference/processor.go index e4b30ea..c401079 100644 --- a/src/conference/messages.go +++ b/src/conference/processor.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" + "github.com/matrix-org/waterfall/src/common" "github.com/matrix-org/waterfall/src/peer" "maunium.net/go/mautrix/event" ) @@ -11,28 +12,27 @@ import ( func (c *Conference) processMessages() { for { // Read a message from the stream (of type peer.Message) and process it. - message := <-c.peerEventsStream + message := <-c.peerEvents c.processPeerMessage(message) } } -//nolint:funlen -func (c *Conference) processPeerMessage(message peer.Message) { +func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { + participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) + if participant == nil { + return + } + // Since Go does not support ADTs, we have to use a switch statement to // determine the actual type of the message. - switch msg := message.(type) { + switch msg := message.Content.(type) { case peer.JoinedTheCall: case peer.LeftTheCall: - delete(c.participants, msg.Sender) + delete(c.participants, message.Sender) // TODO: Send new metadata about available streams to all participants. // TODO: Send the hangup event over the Matrix back to the user. case peer.NewTrackPublished: - participant := c.getParticipant(msg.Sender, errors.New("New track published from unknown participant")) - if participant == nil { - return - } - key := event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -46,11 +46,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { participant.publishedTracks[key] = msg.Track case peer.PublishedTrackFailed: - participant := c.getParticipant(msg.Sender, errors.New("Published track failed from unknown participant")) - if participant == nil { - return - } - delete(participant.publishedTracks, event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -59,11 +54,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { // TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically? case peer.NewICECandidate: - participant := c.getParticipant(msg.Sender, errors.New("ICE candidate from unknown participant")) - if participant == nil { - return - } - // Convert WebRTC ICE candidate to Matrix ICE candidate. jsonCandidate := msg.Candidate.ToJSON() candidates := []event.CallCandidate{{ @@ -74,20 +64,10 @@ func (c *Conference) processPeerMessage(message peer.Message) { c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) case peer.ICEGatheringComplete: - participant := c.getParticipant(msg.Sender, errors.New("Received ICE complete from unknown participant")) - if participant == nil { - return - } - // Send an empty array of candidates to indicate that ICE gathering is complete. c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) case peer.RenegotiationRequired: - participant := c.getParticipant(msg.Sender, errors.New("Renegotiation from unknown participant")) - if participant == nil { - return - } - toSend := event.SFUMessage{ Op: event.SFUOperationOffer, SDP: msg.Offer.SDP, @@ -97,11 +77,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { participant.sendDataChannelMessage(toSend) case peer.DataChannelMessage: - participant := c.getParticipant(msg.Sender, errors.New("Data channel message from unknown participant")) - if participant == nil { - return - } - var sfuMessage event.SFUMessage if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { c.logger.Errorf("Failed to unmarshal SFU message: %v", err) @@ -111,11 +86,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { c.handleDataChannelMessage(participant, sfuMessage) case peer.DataChannelAvailable: - participant := c.getParticipant(msg.Sender, errors.New("Data channel available from unknown participant")) - if participant == nil { - return - } - toSend := event.SFUMessage{ Op: event.SFUOperationMetadata, Metadata: c.getStreamsMetadata(participant.id), diff --git a/src/peer/channel.go b/src/peer/channel.go deleted file mode 100644 index b0c573c..0000000 --- a/src/peer/channel.go +++ /dev/null @@ -1,48 +0,0 @@ -package peer - -import ( - "github.com/pion/webrtc/v3" -) - -type Message = interface{} - -type JoinedTheCall struct { - Sender ID -} - -type LeftTheCall struct { - Sender ID -} - -type NewTrackPublished struct { - Sender ID - Track *webrtc.TrackLocalStaticRTP -} - -type PublishedTrackFailed struct { - Sender ID - Track *webrtc.TrackLocalStaticRTP -} - -type NewICECandidate struct { - Sender ID - Candidate *webrtc.ICECandidate -} - -type ICEGatheringComplete struct { - Sender ID -} - -type RenegotiationRequired struct { - Sender ID - Offer *webrtc.SessionDescription -} - -type DataChannelMessage struct { - Sender ID - Message string -} - -type DataChannelAvailable struct { - Sender ID -} diff --git a/src/peer/id.go b/src/peer/id.go deleted file mode 100644 index e7b4697..0000000 --- a/src/peer/id.go +++ /dev/null @@ -1,8 +0,0 @@ -package peer - -import "maunium.net/go/mautrix/id" - -type ID struct { - UserID id.UserID - DeviceID id.DeviceID -} diff --git a/src/peer/messages.go b/src/peer/messages.go new file mode 100644 index 0000000..51ef1d9 --- /dev/null +++ b/src/peer/messages.go @@ -0,0 +1,35 @@ +package peer + +import ( + "github.com/pion/webrtc/v3" +) + +type MessageContent = interface{} + +type JoinedTheCall struct{} + +type LeftTheCall struct{} + +type NewTrackPublished struct { + Track *webrtc.TrackLocalStaticRTP +} + +type PublishedTrackFailed struct { + Track *webrtc.TrackLocalStaticRTP +} + +type NewICECandidate struct { + Candidate *webrtc.ICECandidate +} + +type ICEGatheringComplete struct{} + +type RenegotiationRequired struct { + Offer *webrtc.SessionDescription +} + +type DataChannelMessage struct { + Message string +} + +type DataChannelAvailable struct{} diff --git a/src/peer/peer.go b/src/peer/peer.go index 84506d0..917682b 100644 --- a/src/peer/peer.go +++ b/src/peer/peer.go @@ -4,6 +4,7 @@ import ( "errors" "sync" + "github.com/matrix-org/waterfall/src/common" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" ) @@ -19,39 +20,30 @@ var ( ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) -type Peer struct { - id ID +type Peer[ID comparable] struct { logger *logrus.Entry - notify chan<- interface{} peerConnection *webrtc.PeerConnection + sink *common.MessageSink[ID, MessageContent] dataChannelMutex sync.Mutex dataChannel *webrtc.DataChannel } -func NewPeer( - info ID, - conferenceId string, +func NewPeer[ID comparable]( sdpOffer string, - notify chan<- interface{}, -) (*Peer, *webrtc.SessionDescription, error) { - logger := logrus.WithFields(logrus.Fields{ - "user_id": info.UserID, - "device_id": info.DeviceID, - "conf_id": conferenceId, - }) - + sink *common.MessageSink[ID, MessageContent], + logger *logrus.Entry, +) (*Peer[ID], *webrtc.SessionDescription, error) { peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) if err != nil { logger.WithError(err).Error("failed to create peer connection") return nil, nil, ErrCantCreatePeerConnection } - peer := &Peer{ - id: info, + peer := &Peer[ID]{ logger: logger, - notify: notify, peerConnection: peerConnection, + sink: sink, } peerConnection.OnTrack(peer.onRtpTrackReceived) @@ -99,15 +91,15 @@ func NewPeer( return peer, sdpAnswer, nil } -func (p *Peer) Terminate() { +func (p *Peer[ID]) Terminate() { if err := p.peerConnection.Close(); err != nil { p.logger.WithError(err).Error("failed to close peer connection") } - p.notify <- LeftTheCall{Sender: p.id} + p.sink.Send(LeftTheCall{}) } -func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { +func (p *Peer[ID]) AddICECandidates(candidates []webrtc.ICECandidateInit) { for _, candidate := range candidates { if err := p.peerConnection.AddICECandidate(candidate); err != nil { p.logger.WithError(err).Error("failed to add ICE candidate") @@ -115,7 +107,7 @@ func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { } } -func (p *Peer) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { +func (p *Peer[ID]) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { _, err := p.peerConnection.AddTrack(track) if err != nil { p.logger.WithError(err).Error("failed to add track") @@ -125,7 +117,7 @@ func (p *Peer) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { return nil } -func (p *Peer) SendOverDataChannel(json string) error { +func (p *Peer[ID]) SendOverDataChannel(json string) error { p.dataChannelMutex.Lock() defer p.dataChannelMutex.Unlock() @@ -146,7 +138,7 @@ func (p *Peer) SendOverDataChannel(json string) error { return nil } -func (p *Peer) NewSDPAnswerReceived(sdpAnswer string) error { +func (p *Peer[ID]) NewSDPAnswerReceived(sdpAnswer string) error { err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeAnswer, SDP: sdpAnswer, diff --git a/src/peer/webrtc.go b/src/peer/webrtc.go index 889416c..46c54ef 100644 --- a/src/peer/webrtc.go +++ b/src/peer/webrtc.go @@ -11,7 +11,7 @@ import ( // A callback that is called once we receive first RTP packets from a track, i.e. // we call this function each time a new track is received. -func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { +func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval. // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI // when a viewer requests it. @@ -40,7 +40,7 @@ func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *web } // Notify others that our track has just been published. - p.notify <- NewTrackPublished{Sender: p.id, Track: localTrack} + p.sink.Send(NewTrackPublished{Track: localTrack}) // Start forwarding the data from the remote track to the local track, // so that everyone who is subscribed to this track will receive the data. @@ -56,31 +56,31 @@ func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *web } else { // finished, no more data, but with error, inform others p.logger.WithError(readErr).Error("failed to read from remote track") } - p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + p.sink.Send(PublishedTrackFailed{Track: localTrack}) } // ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet. if _, err = localTrack.Write(rtpBuf[:index]); err != nil && !errors.Is(err, io.ErrClosedPipe) { p.logger.WithError(err).Error("failed to write to local track") - p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + p.sink.Send(PublishedTrackFailed{Track: localTrack}) } } }() } // A callback that is called once we receive an ICE candidate for this peer connection. -func (p *Peer) onICECandidateGathered(candidate *webrtc.ICECandidate) { +func (p *Peer[ID]) onICECandidateGathered(candidate *webrtc.ICECandidate) { if candidate == nil { p.logger.Info("ICE candidate gathering finished") return } p.logger.WithField("candidate", candidate).Debug("ICE candidate gathered") - p.notify <- NewICECandidate{Sender: p.id, Candidate: candidate} + p.sink.Send(NewICECandidate{Candidate: candidate}) } // A callback that is called once we receive an ICE connection state change for this peer connection. -func (p *Peer) onNegotiationNeeded() { +func (p *Peer[ID]) onNegotiationNeeded() { p.logger.Debug("negotiation needed") offer, err := p.peerConnection.CreateOffer(nil) if err != nil { @@ -93,11 +93,11 @@ func (p *Peer) onNegotiationNeeded() { return } - p.notify <- RenegotiationRequired{Sender: p.id, Offer: &offer} + p.sink.Send(RenegotiationRequired{Offer: &offer}) } // A callback that is called once we receive an ICE connection state change for this peer connection. -func (p *Peer) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { +func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { p.logger.WithField("state", state).Debug("ICE connection state changed") switch state { @@ -111,35 +111,35 @@ func (p *Peer) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { // TODO: Ask Simon if we should do it here as in the previous implementation of the // `waterfall` or the way I did it in this new implementation. // p.notify <- PeerJoinedTheCall{sender: p.data} - p.notify <- ICEGatheringComplete{Sender: p.id} + p.sink.Send(ICEGatheringComplete{}) } } -func (p *Peer) onICEGatheringStateChanged(state webrtc.ICEGathererState) { +func (p *Peer[ID]) onICEGatheringStateChanged(state webrtc.ICEGathererState) { p.logger.WithField("state", state).Debug("ICE gathering state changed") if state == webrtc.ICEGathererStateComplete { - p.notify <- ICEGatheringComplete{Sender: p.id} + p.sink.Send(ICEGatheringComplete{}) } } -func (p *Peer) onSignalingStateChanged(state webrtc.SignalingState) { +func (p *Peer[ID]) onSignalingStateChanged(state webrtc.SignalingState) { p.logger.WithField("state", state).Debug("signaling state changed") } -func (p *Peer) onConnectionStateChanged(state webrtc.PeerConnectionState) { +func (p *Peer[ID]) onConnectionStateChanged(state webrtc.PeerConnectionState) { p.logger.WithField("state", state).Debug("connection state changed") switch state { case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: - p.notify <- LeftTheCall{Sender: p.id} + p.sink.Send(LeftTheCall{}) case webrtc.PeerConnectionStateConnected: - p.notify <- JoinedTheCall{Sender: p.id} + p.sink.Send(JoinedTheCall{}) } } // A callback that is called once the data channel is ready to be used. -func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { +func (p *Peer[ID]) onDataChannelReady(dc *webrtc.DataChannel) { p.dataChannelMutex.Lock() defer p.dataChannelMutex.Unlock() @@ -154,13 +154,13 @@ func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { dc.OnOpen(func() { p.logger.Info("data channel opened") - p.notify <- DataChannelAvailable{Sender: p.id} + p.sink.Send(DataChannelAvailable{}) }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { p.logger.WithField("message", msg).Debug("data channel message received") if msg.IsString { - p.notify <- DataChannelMessage{Sender: p.id, Message: string(msg.Data)} + p.sink.Send(DataChannelMessage{Message: string(msg.Data)}) } else { p.logger.Warn("data channel message is not a string, ignoring") } diff --git a/src/router.go b/src/router.go index e7b3899..ec59de9 100644 --- a/src/router.go +++ b/src/router.go @@ -17,8 +17,7 @@ limitations under the License. package main import ( - "github.com/matrix-org/waterfall/src/conference" - "github.com/matrix-org/waterfall/src/peer" + conf "github.com/matrix-org/waterfall/src/conference" "github.com/matrix-org/waterfall/src/signaling" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" @@ -29,16 +28,16 @@ type Router struct { // Matrix matrix. matrix *signaling.MatrixClient // All calls currently forwarded by this SFU. - conferences map[string]*conference.Conference + conferences map[string]*conf.Conference // Configuration for the calls. - config conference.Config + config conf.Config } // Creates a new instance of the SFU with the given configuration. -func newRouter(matrix *signaling.MatrixClient, config conference.Config) *Router { +func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { return &Router{ matrix: matrix, - conferences: make(map[string]*conference.Conference), + conferences: make(map[string]*conf.Conference), config: config, } } @@ -63,17 +62,17 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - // If there is an invitation sent and the conf does not exist, create one. - if conf := r.conferences[invite.ConfID]; conf == nil { + // If there is an invitation sent and the conference does not exist, create one. + if conference := r.conferences[invite.ConfID]; conference == nil { logger.Infof("creating new conference %s", invite.ConfID) - r.conferences[invite.ConfID] = conference.NewConference( + r.conferences[invite.ConfID] = conf.NewConference( invite.ConfID, r.config, r.matrix.CreateForConference(invite.ConfID), ) } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: invite.DeviceID, } @@ -95,7 +94,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: candidates.DeviceID, } @@ -116,7 +115,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: selectAnswer.DeviceID, } @@ -137,7 +136,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: hangup.DeviceID, } diff --git a/src/signaling/client.go b/src/signaling/client.go new file mode 100644 index 0000000..a16313f --- /dev/null +++ b/src/signaling/client.go @@ -0,0 +1,67 @@ +package signaling + +import ( + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +type MatrixClient struct { + client *mautrix.Client +} + +func NewMatrixClient(config Config) *MatrixClient { + client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) + if err != nil { + logrus.WithError(err).Fatal("Failed to create client") + } + + whoami, err := client.Whoami() + if err != nil { + logrus.WithError(err).Fatal("Failed to identify SFU user") + } + + if config.UserID != whoami.UserID { + logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") + } + + logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") + client.DeviceID = whoami.DeviceID + + return &MatrixClient{ + client: client, + } +} + +// Starts the Matrix client and connects to the homeserver, +// Returns only when the sync with Matrix fails. +func (m *MatrixClient) RunSync(callback func(*event.Event)) { + syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) + if !ok { + logrus.Panic("Syncer is not DefaultSyncer") + } + + syncer.ParseEventContent = true + syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { + // We only care about to-device events. + if evt.Type.Class != event.ToDeviceEventType { + logrus.Warn("ignoring a not to-device event") + return + } + + // We drop the messages if they are not meant for us. + if evt.Content.Raw["dest_session_id"] != LocalSessionID { + logrus.Warn("SessionID does not match our SessionID - ignoring") + return + } + + callback(evt) + }) + + // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU + // as ending here will essentially drop all conferences which may not necessarily + // be what we want for the existing running conferences. + if err := m.client.Sync(); err != nil { + logrus.WithError(err).Panic("Sync failed") + } +} diff --git a/src/signaling/matrix.go b/src/signaling/matrix.go index 3c56a27..65c39d9 100644 --- a/src/signaling/matrix.go +++ b/src/signaling/matrix.go @@ -17,7 +17,6 @@ limitations under the License. package signaling import ( - "github.com/matrix-org/waterfall/src/peer" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -26,66 +25,13 @@ import ( const LocalSessionID = "sfu" -type MatrixClient struct { - client *mautrix.Client -} - -func NewMatrixClient(config Config) *MatrixClient { - client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) - if err != nil { - logrus.WithError(err).Fatal("Failed to create client") - } - - whoami, err := client.Whoami() - if err != nil { - logrus.WithError(err).Fatal("Failed to identify SFU user") - } - - if config.UserID != whoami.UserID { - logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") - } - - logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") - client.DeviceID = whoami.DeviceID - - return &MatrixClient{ - client: client, - } -} - -// Starts the Matrix client and connects to the homeserver, -// Returns only when the sync with Matrix fails. -func (m *MatrixClient) RunSync(callback func(*event.Event)) { - syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) - if !ok { - logrus.Panic("Syncer is not DefaultSyncer") - } - - syncer.ParseEventContent = true - syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { - // We only care about to-device events. - if evt.Type.Class != event.ToDeviceEventType { - logrus.Warn("ignoring a not to-device event") - return - } - - // We drop the messages if they are not meant for us. - if evt.Content.Raw["dest_session_id"] != LocalSessionID { - logrus.Warn("SessionID does not match our SessionID - ignoring") - return - } - - callback(evt) - }) - - // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU - // as ending here will essentially drop all conferences which may not necessarily - // be what we want for the existing running conferences. - if err := m.client.Sync(); err != nil { - logrus.WithError(err).Panic("Sync failed") - } +// Matrix client scoped for a particular conference. +type MatrixForConference struct { + client *mautrix.Client + conferenceID string } +// Create a new Matrix client that abstarcts outgoing Matrix messages from a given conference. func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConference { return &MatrixForConference{ client: m.client, @@ -93,11 +39,14 @@ func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConfer } } +// Defines the data that identifies a receiver of Matrix's to-device message. type MatrixRecipient struct { - ID peer.ID + UserID id.UserID + DeviceID id.DeviceID RemoteSessionID id.SessionID } +// Interface that abstracts sending Send-to-device messages for the conference. type MatrixSignaling interface { SendSDPAnswer(recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, sdp string) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) @@ -105,11 +54,6 @@ type MatrixSignaling interface { SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) } -type MatrixForConference struct { - client *mautrix.Client - conferenceID string -} - func (m *MatrixForConference) SendSDPAnswer( recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, @@ -117,7 +61,7 @@ func (m *MatrixForConference) SendSDPAnswer( ) { eventContent := &event.Content{ Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Answer: event.CallData{ Type: "answer", SDP: sdp, @@ -126,40 +70,40 @@ func (m *MatrixForConference) SendSDPAnswer( }, } - m.sendToDevice(recipient.ID, event.CallAnswer, eventContent) + m.sendToDevice(recipient, event.CallAnswer, eventContent) } func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Candidates: candidates, }, } - m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) + m.sendToDevice(recipient, event.CallCandidates, eventContent) } func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Candidates: []event.CallCandidate{{Candidate: ""}}, }, } - m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) + m.sendToDevice(recipient, event.CallCandidates, eventContent) } func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { eventContent := &event.Content{ Parsed: event.CallHangupEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Reason: reason, }, } - m.sendToDevice(recipient.ID, event.CallHangup, eventContent) + m.sendToDevice(recipient, event.CallHangup, eventContent) } func (m *MatrixForConference) createBaseEventContent( @@ -178,17 +122,17 @@ func (m *MatrixForConference) createBaseEventContent( } // Sends a to-device event to the given user. -func (m *MatrixForConference) sendToDevice(participantID peer.ID, eventType event.Type, eventContent *event.Content) { +func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event.Type, eventContent *event.Content) { // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, + "user_id": user.UserID, + "device_id": user.DeviceID, }) sendRequest := &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - participantID.UserID: { - participantID.DeviceID: eventContent, + user.UserID: { + user.DeviceID: eventContent, }, }, }