From d8e3ee49384ab3bf597a45b06e9d955b188c3a1b Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Fri, 27 Oct 2023 01:09:36 -0500 Subject: [PATCH] feat: finish implementing server side of local-tx-monitor Fixes #352 --- protocol/localtxmonitor/localtxmonitor.go | 19 ++++++ protocol/localtxmonitor/server.go | 78 ++++++++++++++++++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/protocol/localtxmonitor/localtxmonitor.go b/protocol/localtxmonitor/localtxmonitor.go index c2745644..21b3b3ce 100644 --- a/protocol/localtxmonitor/localtxmonitor.go +++ b/protocol/localtxmonitor/localtxmonitor.go @@ -18,6 +18,7 @@ package localtxmonitor import ( "time" + "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -114,10 +115,21 @@ type LocalTxMonitor struct { // Config is used to configure the LocalTxMonitor protocol instance type Config struct { + GetMempoolFunc GetMempoolFunc AcquireTimeout time.Duration QueryTimeout time.Duration } +// Helper types +type TxAndEraId struct { + EraId uint + Tx []byte + txObj ledger.Transaction +} + +// Callback function types +type GetMempoolFunc func() (uint64, uint32, []TxAndEraId, error) + // New returns a new LocalTxMonitor object func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalTxMonitor { l := &LocalTxMonitor{ @@ -143,6 +155,13 @@ func NewConfig(options ...LocalTxMonitorOptionFunc) Config { return c } +// WithGetMempoolFunc specifies the callback function for retrieving the mempool +func WithGetMempoolFunc(getMempoolFunc GetMempoolFunc) LocalTxMonitorOptionFunc { + return func(c *Config) { + c.GetMempoolFunc = getMempoolFunc + } +} + // WithAcquireTimeout specifies the timeout for acquire operations when acting as a client func WithAcquireTimeout(timeout time.Duration) LocalTxMonitorOptionFunc { return func(c *Config) { diff --git a/protocol/localtxmonitor/server.go b/protocol/localtxmonitor/server.go index 57fc0dcb..a10128f8 100644 --- a/protocol/localtxmonitor/server.go +++ b/protocol/localtxmonitor/server.go @@ -15,14 +15,20 @@ package localtxmonitor import ( + "encoding/hex" "fmt" + + "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" ) // Server implements the LocalTxMonitor server type Server struct { *protocol.Protocol - config *Config + config *Config + mempoolCapacity uint32 + mempoolTxs []TxAndEraId + mempoolNextTxIdx int } // NewServer returns a new Server object @@ -72,6 +78,36 @@ func (s *Server) messageHandler(msg protocol.Message, isResponse bool) error { } func (s *Server) handleAcquire() error { + if s.config.GetMempoolFunc == nil { + return fmt.Errorf( + "received local-tx-monitor Acquire message but no GetMempool callback function is defined", + ) + } + // Call the user callback function to get mempool information + mempoolSlotNumber, mempoolCapacity, mempoolTxs, err := s.config.GetMempoolFunc() + if err != nil { + return err + } + s.mempoolCapacity = mempoolCapacity + s.mempoolNextTxIdx = 0 + s.mempoolTxs = make([]TxAndEraId, 0) + for _, mempoolTx := range mempoolTxs { + newTx := TxAndEraId{ + EraId: mempoolTx.EraId, + Tx: mempoolTx.Tx[:], + } + // Pre-parse TX for convenience + tmpTxObj, err := ledger.NewTransactionFromCbor(mempoolTx.EraId, mempoolTx.Tx) + if err != nil { + return err + } + newTx.txObj = tmpTxObj + s.mempoolTxs = append(s.mempoolTxs, newTx) + } + newMsg := NewMsgAcquired(mempoolSlotNumber) + if err := s.SendMessage(newMsg); err != nil { + return err + } return nil } @@ -80,17 +116,57 @@ func (s *Server) handleDone() error { } func (s *Server) handleRelease() error { + s.mempoolCapacity = 0 + s.mempoolTxs = nil return nil } func (s *Server) handleHasTx(msg protocol.Message) error { + msgHasTx := msg.(*MsgHasTx) + txId := hex.EncodeToString(msgHasTx.TxId) + hasTx := false + for _, tx := range s.mempoolTxs { + if tx.txObj.Hash() == txId { + hasTx = true + break + } + } + newMsg := NewMsgReplyHasTx(hasTx) + if err := s.SendMessage(newMsg); err != nil { + return err + } return nil } func (s *Server) handleNextTx() error { + if s.mempoolNextTxIdx > len(s.mempoolTxs) { + newMsg := NewMsgReplyNextTx(0, nil) + if err := s.SendMessage(newMsg); err != nil { + return err + } + return nil + } + mempoolTx := s.mempoolTxs[s.mempoolNextTxIdx] + newMsg := NewMsgReplyNextTx(uint8(mempoolTx.EraId), mempoolTx.Tx) + if err := s.SendMessage(newMsg); err != nil { + return err + } + s.mempoolNextTxIdx++ return nil } func (s *Server) handleGetSizes() error { + totalTxSize := 0 + for _, tx := range s.mempoolTxs { + totalTxSize += len(tx.Tx) + } + newMsg := NewMsgReplyGetSizes( + s.mempoolCapacity, + uint32(totalTxSize), + uint32(len(s.mempoolTxs)), + ) + if err := s.SendMessage(newMsg); err != nil { + return err + } return nil }