Skip to content

Commit

Permalink
Merge pull request #16 from cloudstruct/feature/protocol-refactor
Browse files Browse the repository at this point in the history
Refactor mini-protocols
  • Loading branch information
agaffney authored Feb 13, 2022
2 parents c63495e + 78dee7a commit 11df25c
Show file tree
Hide file tree
Showing 12 changed files with 487 additions and 421 deletions.
2 changes: 1 addition & 1 deletion cmd/go-ouroboros-network/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func chainSyncIntersectFoundHandler(point interface{}, tip interface{}) error {
return nil
}

func chainSyncIntersectNotFoundHandler() error {
func chainSyncIntersectNotFoundHandler(tip interface{}) error {
fmt.Printf("ERROR: failed to find intersection\n")
os.Exit(1)
return nil
Expand Down
51 changes: 0 additions & 51 deletions muxer/message.go

This file was deleted.

40 changes: 27 additions & 13 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ import (
"encoding/binary"
"fmt"
"io"
"sync"
)

const (
// Magic number chosen to represent unknown protocols
PROTOCOL_UNKNOWN uint16 = 0xabcd
)

type Muxer struct {
conn io.ReadWriteCloser
sendMutex sync.Mutex
ErrorChan chan error
protocolSenders map[uint16]chan *Message
protocolReceivers map[uint16]chan *Message
protocolSenders map[uint16]chan *Segment
protocolReceivers map[uint16]chan *Segment
}

func New(conn io.ReadWriteCloser) *Muxer {
m := &Muxer{
conn: conn,
ErrorChan: make(chan error, 10),
protocolSenders: make(map[uint16]chan *Message),
protocolReceivers: make(map[uint16]chan *Message),
protocolSenders: make(map[uint16]chan *Segment),
protocolReceivers: make(map[uint16]chan *Segment),
}
go m.readLoop()
return m
}

func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Message) {
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
// Generate channels
senderChan := make(chan *Message, 10)
receiverChan := make(chan *Message, 10)
senderChan := make(chan *Segment, 10)
receiverChan := make(chan *Segment, 10)
// Record channels in protocol sender/receiver maps
m.protocolSenders[protocolId] = senderChan
m.protocolReceivers[protocolId] = receiverChan
Expand All @@ -44,9 +51,12 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Messag
return senderChan, receiverChan
}

func (m *Muxer) Send(msg *Message) error {
func (m *Muxer) Send(msg *Segment) error {
// We use a mutex to make sure only one protocol can send at a time
m.sendMutex.Lock()
defer m.sendMutex.Unlock()
buf := &bytes.Buffer{}
err := binary.Write(buf, binary.BigEndian, msg.MessageHeader)
err := binary.Write(buf, binary.BigEndian, msg.SegmentHeader)
if err != nil {
return err
}
Expand All @@ -60,12 +70,12 @@ func (m *Muxer) Send(msg *Message) error {

func (m *Muxer) readLoop() {
for {
header := MessageHeader{}
header := SegmentHeader{}
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
m.ErrorChan <- err
}
msg := &Message{
MessageHeader: header,
msg := &Segment{
SegmentHeader: header,
Payload: make([]byte, header.PayloadLength),
}
// We use ReadFull because it guarantees to read the expected number of bytes or
Expand All @@ -76,7 +86,11 @@ func (m *Muxer) readLoop() {
// Send message payload to proper receiver
recvChan := m.protocolReceivers[msg.GetProtocolId()]
if recvChan == nil {
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
// Try the "unknown protocol" receiver if we didn't find an explicit one
recvChan = m.protocolReceivers[PROTOCOL_UNKNOWN]
if recvChan == nil {
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
}
} else {
m.protocolReceivers[msg.GetProtocolId()] <- msg
}
Expand Down
51 changes: 51 additions & 0 deletions muxer/segment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package muxer

import (
"time"
)

const (
SEGMENT_PROTOCOL_ID_RESPONSE_FLAG = 0x8000
)

type SegmentHeader struct {
Timestamp uint32
ProtocolId uint16
PayloadLength uint16
}

type Segment struct {
SegmentHeader
Payload []byte
}

func NewSegment(protocolId uint16, payload []byte, isResponse bool) *Segment {
header := SegmentHeader{
Timestamp: uint32(time.Now().UnixNano() & 0xffffffff),
ProtocolId: protocolId,
}
if isResponse {
header.ProtocolId = header.ProtocolId + SEGMENT_PROTOCOL_ID_RESPONSE_FLAG
}
header.PayloadLength = uint16(len(payload))
segment := &Segment{
SegmentHeader: header,
Payload: payload,
}
return segment
}

func (s *SegmentHeader) IsRequest() bool {
return (s.ProtocolId & SEGMENT_PROTOCOL_ID_RESPONSE_FLAG) == 0
}

func (s *SegmentHeader) IsResponse() bool {
return (s.ProtocolId & SEGMENT_PROTOCOL_ID_RESPONSE_FLAG) > 0
}

func (s *SegmentHeader) GetProtocolId() uint16 {
if s.ProtocolId >= SEGMENT_PROTOCOL_ID_RESPONSE_FLAG {
return s.ProtocolId - SEGMENT_PROTOCOL_ID_RESPONSE_FLAG
}
return s.ProtocolId
}
Loading

0 comments on commit 11df25c

Please sign in to comment.