diff --git a/cmd/go-ouroboros-network/chainsync.go b/cmd/go-ouroboros-network/chainsync.go index 4f88a44f..361e00fe 100644 --- a/cmd/go-ouroboros-network/chainsync.go +++ b/cmd/go-ouroboros-network/chainsync.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/cloudstruct/go-cardano-ledger" ouroboros "github.com/cloudstruct/go-ouroboros-network" - "github.com/cloudstruct/go-ouroboros-network/protocol/blockfetch" "github.com/cloudstruct/go-ouroboros-network/protocol/chainsync" "github.com/cloudstruct/go-ouroboros-network/protocol/common" "os" @@ -86,15 +85,6 @@ func buildChainSyncConfig() chainsync.Config { ) } -func buildBlockFetchConfig() blockfetch.Config { - return blockfetch.NewConfig( - blockfetch.WithStartBatchFunc(blockFetchStartBatchHandler), - blockfetch.WithNoBlocksFunc(blockFetchNoBlocksHandler), - blockfetch.WithBlockFunc(blockFetchBlockHandler), - blockfetch.WithBatchDoneFunc(blockFetchBatchDoneHandler), - ) -} - func testChainSync(f *globalFlags) { chainSyncFlags := newChainSyncFlags() err := chainSyncFlags.flagset.Parse(f.flagset.Args()[1:]) @@ -133,7 +123,6 @@ func testChainSync(f *globalFlags) { ouroboros.WithErrorChan(errorChan), ouroboros.WithNodeToNode(f.ntnProto), ouroboros.WithKeepAlive(true), - ouroboros.WithBlockFetchConfig(buildBlockFetchConfig()), ouroboros.WithChainSyncConfig(buildChainSyncConfig()), ) if err != nil { @@ -177,25 +166,16 @@ func chainSyncRollBackwardHandler(point common.Point, tip chainsync.Tip) error { } func chainSyncRollForwardHandler(blockType uint, blockData interface{}, tip chainsync.Tip) error { + var block ledger.Block switch v := blockData.(type) { case ledger.Block: - switch blockType { - case ledger.BLOCK_TYPE_BYRON_EBB: - byronEbbBlock := v.(*ledger.ByronEpochBoundaryBlock) - fmt.Printf("era = Byron (EBB), epoch = %d, id = %s\n", byronEbbBlock.Header.ConsensusData.Epoch, byronEbbBlock.Hash()) - case ledger.BLOCK_TYPE_BYRON_MAIN: - byronBlock := v.(*ledger.ByronMainBlock) - fmt.Printf("era = Byron, epoch = %d, slot = %d, id = %s\n", byronBlock.Header.ConsensusData.SlotId.Epoch, byronBlock.SlotNumber(), byronBlock.Hash()) - default: - fmt.Printf("era = %s, slot = %d, block_no = %d, id = %s\n", v.Era().Name, v.SlotNumber(), v.BlockNumber(), v.Hash()) - } + block = v case ledger.BlockHeader: var blockSlot uint64 var blockHash []byte switch blockType { case ledger.BLOCK_TYPE_BYRON_EBB: byronEbbHeader := v.(*ledger.ByronEpochBoundaryBlockHeader) - //fmt.Printf("era = Byron (EBB), epoch = %d, id = %s\n", h.ConsensusData.Epoch, h.Hash()) if syncState.byronEpochSlot > 0 { syncState.byronEpochBaseSlot += syncState.byronEpochSlot + 1 } @@ -203,7 +183,6 @@ func chainSyncRollForwardHandler(blockType uint, blockData interface{}, tip chai blockHash, _ = hex.DecodeString(byronEbbHeader.Hash()) case ledger.BLOCK_TYPE_BYRON_MAIN: byronHeader := v.(*ledger.ByronMainBlockHeader) - //fmt.Printf("era = Byron, epoch = %d, slot = %d, id = %s\n", h.ConsensusData.SlotId.Epoch, h.ConsensusData.SlotId.Slot, h.Hash()) syncState.byronEpochSlot = uint64(byronHeader.ConsensusData.SlotId.Slot) blockSlot = syncState.byronEpochBaseSlot + syncState.byronEpochSlot blockHash, _ = hex.DecodeString(byronHeader.Hash()) @@ -211,38 +190,22 @@ func chainSyncRollForwardHandler(blockType uint, blockData interface{}, tip chai blockSlot = v.SlotNumber() blockHash, _ = hex.DecodeString(v.Hash()) } - if err := syncState.oConn.BlockFetch().Client.RequestRange([]interface{}{blockSlot, blockHash}, []interface{}{blockSlot, blockHash}); err != nil { - fmt.Printf("error calling RequestRange: %s\n", err) + var err error + block, err = syncState.oConn.BlockFetch().Client.GetBlock(common.NewPoint(blockSlot, blockHash)) + if err != nil { return err } } - return nil -} - -func blockFetchStartBatchHandler() error { - return nil -} - -func blockFetchNoBlocksHandler() error { - fmt.Printf("blockFetchNoBlocksHandler()\n") - return nil -} - -func blockFetchBlockHandler(blockType uint, blockData interface{}) error { + // Display block info switch blockType { case ledger.BLOCK_TYPE_BYRON_EBB: - b := blockData.(*ledger.ByronEpochBoundaryBlock) - fmt.Printf("era = Byron (EBB), id = %s\n", b.Hash()) + byronEbbBlock := block.(*ledger.ByronEpochBoundaryBlock) + fmt.Printf("era = Byron (EBB), epoch = %d, id = %s\n", byronEbbBlock.Header.ConsensusData.Epoch, byronEbbBlock.Hash()) case ledger.BLOCK_TYPE_BYRON_MAIN: - b := blockData.(*ledger.ByronMainBlock) - fmt.Printf("era = Byron, epoch = %d, slot = %d, id = %s\n", b.Header.ConsensusData.SlotId.Epoch, b.SlotNumber(), b.Hash()) + byronBlock := block.(*ledger.ByronMainBlock) + fmt.Printf("era = Byron, epoch = %d, slot = %d, id = %s\n", byronBlock.Header.ConsensusData.SlotId.Epoch, byronBlock.SlotNumber(), byronBlock.Hash()) default: - b := blockData.(ledger.Block) - fmt.Printf("era = %s, slot = %d, block_no = %d, id = %s\n", b.Era().Name, b.SlotNumber(), b.BlockNumber(), b.Hash()) + fmt.Printf("era = %s, slot = %d, block_no = %d, id = %s\n", block.Era().Name, block.SlotNumber(), block.BlockNumber(), block.Hash()) } return nil } - -func blockFetchBatchDoneHandler() error { - return nil -} diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 9d9b5ef3..6ccff604 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -4,6 +4,8 @@ import ( "time" "github.com/cloudstruct/go-ouroboros-network/protocol" + + "github.com/cloudstruct/go-cardano-ledger" ) const ( @@ -69,19 +71,13 @@ type BlockFetch struct { } type Config struct { - StartBatchFunc StartBatchFunc - NoBlocksFunc NoBlocksFunc BlockFunc BlockFunc - BatchDoneFunc BatchDoneFunc BatchStartTimeout time.Duration BlockTimeout time.Duration } // Callback function types -type StartBatchFunc func() error -type NoBlocksFunc func() error -type BlockFunc func(uint, interface{}) error -type BatchDoneFunc func() error +type BlockFunc func(ledger.Block) error func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { b := &BlockFetch{ @@ -105,30 +101,12 @@ func NewConfig(options ...BlockFetchOptionFunc) Config { return c } -func WithStartBatchFunc(startBatchFunc StartBatchFunc) BlockFetchOptionFunc { - return func(c *Config) { - c.StartBatchFunc = startBatchFunc - } -} - -func WithNoBlocksFunc(noBlocksFunc NoBlocksFunc) BlockFetchOptionFunc { - return func(c *Config) { - c.NoBlocksFunc = noBlocksFunc - } -} - func WithBlockFunc(blockFunc BlockFunc) BlockFetchOptionFunc { return func(c *Config) { c.BlockFunc = blockFunc } } -func WithBatchDoneFunc(BatchDoneFunc BatchDoneFunc) BlockFetchOptionFunc { - return func(c *Config) { - c.BatchDoneFunc = BatchDoneFunc - } -} - func WithBatchStartTimeout(timeout time.Duration) BlockFetchOptionFunc { return func(c *Config) { c.BatchStartTimeout = timeout diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index e00ece69..20a56317 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -2,14 +2,22 @@ package blockfetch import ( "fmt" - "github.com/cloudstruct/go-cardano-ledger" + "sync" + "github.com/cloudstruct/go-ouroboros-network/protocol" + "github.com/cloudstruct/go-ouroboros-network/protocol/common" "github.com/cloudstruct/go-ouroboros-network/utils" + + "github.com/cloudstruct/go-cardano-ledger" ) type Client struct { *protocol.Protocol - config *Config + config *Config + blockChan chan ledger.Block + startBatchResultChan chan error + busyMutex sync.Mutex + blockUseCallback bool } func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { @@ -18,7 +26,9 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { cfg = &tmpCfg } c := &Client{ - config: cfg, + config: cfg, + blockChan: make(chan ledger.Block), + startBatchResultChan: make(chan error), } // Update state map with timeouts stateMap := StateMap.Copy() @@ -44,17 +54,55 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { InitialState: STATE_IDLE, } c.Protocol = protocol.New(protoConfig) + // Start goroutine to cleanup resources on protocol shutdown + go func() { + <-c.Protocol.DoneChan() + close(c.blockChan) + }() return c } -func (c *Client) RequestRange(start []interface{}, end []interface{}) error { - msg := NewMsgRequestRange(start, end) +func (c *Client) Stop() error { + msg := NewMsgClientDone() return c.SendMessage(msg) } -func (c *Client) ClientDone() error { - msg := NewMsgClientDone() - return c.SendMessage(msg) +// GetBlockRange starts an async process to fetch all blocks in the specified range (inclusive) +func (c *Client) GetBlockRange(start common.Point, end common.Point) error { + c.busyMutex.Lock() + c.blockUseCallback = true + msg := NewMsgRequestRange(start, end) + if err := c.SendMessage(msg); err != nil { + c.busyMutex.Unlock() + return err + } + err := <-c.startBatchResultChan + if err != nil { + c.busyMutex.Unlock() + return err + } + return nil +} + +// GetBlock requests and returns a single block specified by the provided point +func (c *Client) GetBlock(point common.Point) (ledger.Block, error) { + c.busyMutex.Lock() + c.blockUseCallback = false + msg := NewMsgRequestRange(point, point) + if err := c.SendMessage(msg); err != nil { + c.busyMutex.Unlock() + return nil, err + } + err := <-c.startBatchResultChan + if err != nil { + c.busyMutex.Unlock() + return nil, err + } + block, ok := <-c.blockChan + if !ok { + return nil, protocol.ProtocolShuttingDownError + } + return block, nil } func (c *Client) messageHandler(msg protocol.Message, isResponse bool) error { @@ -75,25 +123,17 @@ func (c *Client) messageHandler(msg protocol.Message, isResponse bool) error { } func (c *Client) handleStartBatch() error { - if c.config.StartBatchFunc == nil { - return fmt.Errorf("received block-fetch StartBatch message but no callback function is defined") - } - // Call the user callback function - return c.config.StartBatchFunc() + c.startBatchResultChan <- nil + return nil } func (c *Client) handleNoBlocks() error { - if c.config.NoBlocksFunc == nil { - return fmt.Errorf("received block-fetch NoBlocks message but no callback function is defined") - } - // Call the user callback function - return c.config.NoBlocksFunc() + err := fmt.Errorf("block(s) not found") + c.startBatchResultChan <- err + return nil } func (c *Client) handleBlock(msgGeneric protocol.Message) error { - if c.config.BlockFunc == nil { - return fmt.Errorf("received block-fetch Block message but no callback function is defined") - } msg := msgGeneric.(*MsgBlock) // Decode only enough to get the block type value var wrappedBlock WrappedBlock @@ -104,14 +144,18 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error { if err != nil { return err } - // Call the user callback function - return c.config.BlockFunc(wrappedBlock.Type, blk) + // We use the callback when requesting ranges and the internal channel for a single block + if c.blockUseCallback { + if err := c.config.BlockFunc(blk); err != nil { + return err + } + } else { + c.blockChan <- blk + } + return nil } func (c *Client) handleBatchDone() error { - if c.config.BatchDoneFunc == nil { - return fmt.Errorf("received block-fetch BatchDone message but no callback function is defined") - } - // Call the user callback function - return c.config.BatchDoneFunc() + c.busyMutex.Unlock() + return nil } diff --git a/protocol/error.go b/protocol/error.go new file mode 100644 index 00000000..205d4ca7 --- /dev/null +++ b/protocol/error.go @@ -0,0 +1,7 @@ +package protocol + +import ( + "fmt" +) + +var ProtocolShuttingDownError = fmt.Errorf("protocol is shutting down") diff --git a/protocol/localtxmonitor/client.go b/protocol/localtxmonitor/client.go index 7c6d194e..9d557515 100644 --- a/protocol/localtxmonitor/client.go +++ b/protocol/localtxmonitor/client.go @@ -143,7 +143,7 @@ func (c *Client) HasTx(txId []byte) (bool, error) { } result, ok := <-c.hasTxResultChan if !ok { - return false, fmt.Errorf("protocol is shutting down") + return false, protocol.ProtocolShuttingDownError } return result, nil } @@ -163,7 +163,7 @@ func (c *Client) NextTx() ([]byte, error) { } tx, ok := <-c.nextTxResultChan if !ok { - return nil, fmt.Errorf("protocol is shutting down") + return nil, protocol.ProtocolShuttingDownError } return tx, nil } @@ -183,7 +183,7 @@ func (c *Client) GetSizes() (uint32, uint32, uint32, error) { } result, ok := <-c.getSizesResultChan if !ok { - return 0, 0, 0, fmt.Errorf("protocol is shutting down") + return 0, 0, 0, protocol.ProtocolShuttingDownError } return result.Capacity, result.Size, result.NumberOfTxs, nil }