Skip to content

Commit

Permalink
Merge pull request #410 from blinklabs-io/feat/protocol-localtxmonito…
Browse files Browse the repository at this point in the history
…r-server

feat: finish implementing server side of local-tx-monitor
  • Loading branch information
agaffney authored Oct 27, 2023
2 parents 7e0b57d + d8e3ee4 commit f1cfd4a
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
19 changes: 19 additions & 0 deletions protocol/localtxmonitor/localtxmonitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package localtxmonitor
import (
"time"

"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/protocol"
)

Expand Down Expand Up @@ -114,10 +115,21 @@ type LocalTxMonitor struct {

// Config is used to configure the LocalTxMonitor protocol instance
type Config struct {
GetMempoolFunc GetMempoolFunc
AcquireTimeout time.Duration
QueryTimeout time.Duration
}

// Helper types
type TxAndEraId struct {
EraId uint
Tx []byte
txObj ledger.Transaction
}

// Callback function types
type GetMempoolFunc func() (uint64, uint32, []TxAndEraId, error)

// New returns a new LocalTxMonitor object
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalTxMonitor {
l := &LocalTxMonitor{
Expand All @@ -143,6 +155,13 @@ func NewConfig(options ...LocalTxMonitorOptionFunc) Config {
return c
}

// WithGetMempoolFunc specifies the callback function for retrieving the mempool
func WithGetMempoolFunc(getMempoolFunc GetMempoolFunc) LocalTxMonitorOptionFunc {
return func(c *Config) {
c.GetMempoolFunc = getMempoolFunc
}
}

// WithAcquireTimeout specifies the timeout for acquire operations when acting as a client
func WithAcquireTimeout(timeout time.Duration) LocalTxMonitorOptionFunc {
return func(c *Config) {
Expand Down
78 changes: 77 additions & 1 deletion protocol/localtxmonitor/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
package localtxmonitor

import (
"encoding/hex"
"fmt"

"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/protocol"
)

// Server implements the LocalTxMonitor server
type Server struct {
*protocol.Protocol
config *Config
config *Config
mempoolCapacity uint32
mempoolTxs []TxAndEraId
mempoolNextTxIdx int
}

// NewServer returns a new Server object
Expand Down Expand Up @@ -72,6 +78,36 @@ func (s *Server) messageHandler(msg protocol.Message, isResponse bool) error {
}

func (s *Server) handleAcquire() error {
if s.config.GetMempoolFunc == nil {
return fmt.Errorf(
"received local-tx-monitor Acquire message but no GetMempool callback function is defined",
)
}
// Call the user callback function to get mempool information
mempoolSlotNumber, mempoolCapacity, mempoolTxs, err := s.config.GetMempoolFunc()
if err != nil {
return err
}
s.mempoolCapacity = mempoolCapacity
s.mempoolNextTxIdx = 0
s.mempoolTxs = make([]TxAndEraId, 0)
for _, mempoolTx := range mempoolTxs {
newTx := TxAndEraId{
EraId: mempoolTx.EraId,
Tx: mempoolTx.Tx[:],
}
// Pre-parse TX for convenience
tmpTxObj, err := ledger.NewTransactionFromCbor(mempoolTx.EraId, mempoolTx.Tx)
if err != nil {
return err
}
newTx.txObj = tmpTxObj
s.mempoolTxs = append(s.mempoolTxs, newTx)
}
newMsg := NewMsgAcquired(mempoolSlotNumber)
if err := s.SendMessage(newMsg); err != nil {
return err
}
return nil
}

Expand All @@ -80,17 +116,57 @@ func (s *Server) handleDone() error {
}

func (s *Server) handleRelease() error {
s.mempoolCapacity = 0
s.mempoolTxs = nil
return nil
}

func (s *Server) handleHasTx(msg protocol.Message) error {
msgHasTx := msg.(*MsgHasTx)
txId := hex.EncodeToString(msgHasTx.TxId)
hasTx := false
for _, tx := range s.mempoolTxs {
if tx.txObj.Hash() == txId {
hasTx = true
break
}
}
newMsg := NewMsgReplyHasTx(hasTx)
if err := s.SendMessage(newMsg); err != nil {
return err
}
return nil
}

func (s *Server) handleNextTx() error {
if s.mempoolNextTxIdx > len(s.mempoolTxs) {
newMsg := NewMsgReplyNextTx(0, nil)
if err := s.SendMessage(newMsg); err != nil {
return err
}
return nil
}
mempoolTx := s.mempoolTxs[s.mempoolNextTxIdx]
newMsg := NewMsgReplyNextTx(uint8(mempoolTx.EraId), mempoolTx.Tx)
if err := s.SendMessage(newMsg); err != nil {
return err
}
s.mempoolNextTxIdx++
return nil
}

func (s *Server) handleGetSizes() error {
totalTxSize := 0
for _, tx := range s.mempoolTxs {
totalTxSize += len(tx.Tx)
}
newMsg := NewMsgReplyGetSizes(
s.mempoolCapacity,
uint32(totalTxSize),
uint32(len(s.mempoolTxs)),
)
if err := s.SendMessage(newMsg); err != nil {
return err
}
return nil
}

0 comments on commit f1cfd4a

Please sign in to comment.