Skip to content

Commit

Permalink
fix: utilize 'msg' parameter in handleRequestNext and pass it to call…
Browse files Browse the repository at this point in the history
…back

Signed-off-by: Ales Verbic <[email protected]>
  • Loading branch information
verbotenj committed Oct 1, 2024
1 parent 1cb4e76 commit fa5f78e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 6 deletions.
2 changes: 1 addition & 1 deletion protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ type RollBackwardFunc func(CallbackContext, common.Point, Tip) error
type RollForwardFunc func(CallbackContext, uint, interface{}, Tip) error

type FindIntersectFunc func(CallbackContext, []common.Point) (common.Point, Tip, error)
type RequestNextFunc func(CallbackContext) error
type RequestNextFunc func(CallbackContext, *MsgRequestNext) error

// New returns a new ChainSync object
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {
Expand Down
9 changes: 5 additions & 4 deletions protocol/chainsync/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ package chainsync
import (
"encoding/hex"
"fmt"
"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/protocol"
"github.com/blinklabs-io/gouroboros/protocol/common"
"os"
"reflect"
"strings"
"testing"

"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/protocol"
"github.com/blinklabs-io/gouroboros/protocol/common"
)

type testDefinition struct {
Expand Down
7 changes: 6 additions & 1 deletion protocol/chainsync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ func (s *Server) handleRequestNext(msg protocol.Message) error {
"received chain-sync RequestNext message but no callback function is defined",
)
}
return s.config.RequestNextFunc(s.callbackContext)
msgRequestNext, ok := msg.(*MsgRequestNext)
if !ok {
return fmt.Errorf("expected MsgRequestNext, got %T", msg)
}
// Pass the message to the callback function
return s.config.RequestNextFunc(s.callbackContext, msgRequestNext)
}

func (s *Server) handleFindIntersect(msg protocol.Message) error {
Expand Down
88 changes: 88 additions & 0 deletions protocol/chainsync/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2024 Blink Labs Software
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package chainsync

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestHandleRequestNext_ValidMessage(t *testing.T) {
called := false
var receivedMsg *MsgRequestNext

server := &Server{
config: &Config{
RequestNextFunc: func(ctx CallbackContext, msg *MsgRequestNext) error {
called = true
receivedMsg = msg
// Ensure that the CBOR data is not empty
if len(msg.Cbor()) == 0 {
return fmt.Errorf("expected non-empty CBOR data")
}
return nil
},
},
callbackContext: CallbackContext{},
}

msg := &MsgRequestNext{}
// Fake CBOR data
rawCborData := []byte{0x01, 0x02, 0x03}
msg.SetCbor(rawCborData)

err := server.handleRequestNext(msg)

assert.NoError(t, err, "expected no error")
assert.True(t, called, "expected RequestNextFunc to be called")
assert.Equal(t, msg, receivedMsg, "expected received message to be the same as sent message")
assert.Equal(t, rawCborData, receivedMsg.Cbor(), "expected raw CBOR data to be passed correctly")
}

func TestHandleRequestNext_InvalidMessageType(t *testing.T) {
server := &Server{
config: &Config{
RequestNextFunc: func(ctx CallbackContext, msg *MsgRequestNext) error {
return nil
},
},
callbackContext: CallbackContext{},
}

msg := &MsgFindIntersect{}
err := server.handleRequestNext(msg)
expectedError := fmt.Sprintf("expected MsgRequestNext, got %T", msg)

assert.Error(t, err, "expected an error due to invalid message type")
assert.EqualError(t, err, expectedError)
}

func TestHandleRequestNext_NilCallback(t *testing.T) {
server := &Server{
config: &Config{
RequestNextFunc: nil,
},
callbackContext: CallbackContext{},
}

msg := &MsgRequestNext{}
err := server.handleRequestNext(msg)
expectedError := "received chain-sync RequestNext message but no callback function is defined"

assert.Error(t, err, "expected an error due to nil callback")
assert.EqualError(t, err, expectedError)
}

0 comments on commit fa5f78e

Please sign in to comment.