Skip to content

Commit

Permalink
Improve message validation (#175)
Browse files Browse the repository at this point in the history
* Improve message validation

* Remove obsolete test cases

* Fix validation

* Update log message usage - stringer vs str
  • Loading branch information
Maelkum authored Nov 27, 2024
1 parent adee252 commit 1f86fc1
Show file tree
Hide file tree
Showing 17 changed files with 165 additions and 107 deletions.
5 changes: 5 additions & 0 deletions api/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func (a *API) ExecuteFunction(ctx echo.Context) error {
Parameters: req.Parameters,
}

err = exr.Valid()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid request: %w", err))
}

// Get the execution result.
code, id, results, cluster, err := a.Node.ExecuteFunction(ctx.Request().Context(), exr, req.Topic)
if err != nil {
Expand Down
14 changes: 12 additions & 2 deletions api/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ const (
functionInstallTimeout = 10 * time.Second
)

func (r FunctionInstallRequest) Valid() error {

if r.Cid == "" {
return errors.New("function CID is required")
}

return nil
}

func (a *API) InstallFunction(ctx echo.Context) error {

// Unpack the API request.
Expand All @@ -24,8 +33,9 @@ func (a *API) InstallFunction(ctx echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("could not unpack request: %w", err))
}

if req.Uri == "" && req.Cid == "" {
return echo.NewHTTPError(http.StatusBadRequest, errors.New("URI or CID are required"))
err = req.Valid()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid request: %w", err))
}

// Add a deadline to the context.
Expand Down
15 changes: 12 additions & 3 deletions api/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ import (
"github.com/labstack/echo/v4"
)

func (r FunctionResultRequest) Valid() error {

if r.Id == "" {
return errors.New("request ID is required")
}

return nil
}

// ExecutionResult implements the REST API endpoint for retrieving the result of a function execution.
func (a *API) ExecutionResult(ctx echo.Context) error {

Expand All @@ -18,13 +27,13 @@ func (a *API) ExecutionResult(ctx echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("could not unpack request: %w", err))
}

requestID := request.Id
if requestID == "" {
err = request.Valid()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, errors.New("missing request ID"))
}

// Lookup execution result.
result, ok := a.Node.ExecutionResult(requestID)
result, ok := a.Node.ExecutionResult(request.Id)
if !ok {
return ctx.NoContent(http.StatusNotFound)
}
Expand Down
21 changes: 21 additions & 0 deletions models/execute/request.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package execute

import (
"errors"

"github.com/hashicorp/go-multierror"
)

// Request describes an execution request.
type Request struct {
FunctionID string `json:"function_id"`
Expand All @@ -11,6 +17,21 @@ type Request struct {
Signature string `json:"signature,omitempty"`
}

func (r Request) Valid() error {

var err *multierror.Error

if r.FunctionID == "" {
err = multierror.Append(err, errors.New("function ID is required"))
}

if r.Method == "" {
err = multierror.Append(err, errors.New("method is required"))
}

return err.ErrorOrNil()
}

// Parameter represents an execution parameter, modeled as a key-value pair.
type Parameter struct {
Name string `json:"name,omitempty"`
Expand Down
14 changes: 11 additions & 3 deletions models/request/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/blocklessnetwork/b7s/models/codes"
"github.com/blocklessnetwork/b7s/models/execute"
"github.com/blocklessnetwork/b7s/models/response"
"github.com/hashicorp/go-multierror"
)

var _ (json.Marshaler) = (*Execute)(nil)
Expand Down Expand Up @@ -50,16 +51,23 @@ func (e Execute) MarshalJSON() ([]byte, error) {

func (e Execute) Valid() error {

var multierr *multierror.Error
err := e.Request.Valid()
if err != nil {
multierr = multierror.Append(multierr, err)
}

c, err := consensus.Parse(e.Config.ConsensusAlgorithm)
if err != nil {
return fmt.Errorf("could not parse consensus algorithm: %w", err)
multierr = multierror.Append(multierr, fmt.Errorf("could not parse consensus algorithm: %w", err))
}

if c == consensus.PBFT &&
e.Config.NodeCount > 0 &&
e.Config.NodeCount < pbft.MinimumReplicaCount {
return fmt.Errorf("minimum %v nodes needed for PBFT consensus", pbft.MinimumReplicaCount)

multierr = multierror.Append(multierr, fmt.Errorf("minimum %v nodes needed for PBFT consensus", pbft.MinimumReplicaCount))
}

return nil
return multierr.ErrorOrNil()
}
10 changes: 10 additions & 0 deletions models/request/install_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package request

import (
"encoding/json"
"errors"

"github.com/blocklessnetwork/b7s/models/blockless"
"github.com/blocklessnetwork/b7s/models/codes"
Expand Down Expand Up @@ -39,3 +40,12 @@ func (f InstallFunction) MarshalJSON() ([]byte, error) {
}
return json.Marshal(rec)
}

func (f InstallFunction) Valid() error {

if f.CID == "" {
return errors.New("function CID is required")
}

return nil
}
24 changes: 6 additions & 18 deletions node/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@ import (

func (n *Node) processFormCluster(ctx context.Context, from peer.ID, req request.FormCluster) error {

// Should never happen.
if !n.isWorker() {
n.log.Warn().Str("peer", from.String()).Msg("only worker nodes participate in consensus clusters")
return nil
}

n.log.Info().Str("request", req.RequestID).Strs("peers", blockless.PeerIDsToStr(req.Peers)).Str("consensus", req.Consensus.String()).Msg("received request to form consensus cluster")
n.log.Info().Str("request", req.RequestID).Strs("peers", blockless.PeerIDsToStr(req.Peers)).Stringer("consensus", req.Consensus).Msg("received request to form consensus cluster")

// Add connection info about peers if we're not already connected to them.
for _, addrInfo := range req.ConnectionInfo {
Expand Down Expand Up @@ -60,7 +54,7 @@ func (n *Node) processFormCluster(ctx context.Context, from peer.ID, req request
// processFormClusterResponse will record the cluster formation response.
func (n *Node) processFormClusterResponse(ctx context.Context, from peer.ID, res response.FormCluster) error {

n.log.Debug().Str("request", res.RequestID).Str("from", from.String()).Msg("received cluster formation response")
n.log.Debug().Str("request", res.RequestID).Stringer("from", from).Msg("received cluster formation response")

key := consensusResponseKey(res.RequestID, from)
n.consensusResponses.Set(key, res)
Expand All @@ -71,20 +65,14 @@ func (n *Node) processFormClusterResponse(ctx context.Context, from peer.ID, res
// processDisbandCluster will start cluster shutdown command.
func (n *Node) processDisbandCluster(ctx context.Context, from peer.ID, req request.DisbandCluster) error {

// Should never happen.
if !n.isWorker() {
n.log.Warn().Str("peer", from.String()).Msg("only worker nodes participate in consensus clusters")
return nil
}

n.log.Info().Str("peer", from.String()).Str("request", req.RequestID).Msg("received request to disband consensus cluster")
n.log.Info().Stringer("peer", from).Str("request", req.RequestID).Msg("received request to disband consensus cluster")

err := n.leaveCluster(req.RequestID, consensusClusterDisbandTimeout)
if err != nil {
return fmt.Errorf("could not disband cluster (request: %s): %w", req.RequestID, err)
}

n.log.Info().Str("peer", from.String()).Str("request", req.RequestID).Msg("left consensus cluster")
n.log.Info().Stringer("peer", from).Str("request", req.RequestID).Msg("left consensus cluster")

return nil
}
Expand Down Expand Up @@ -140,10 +128,10 @@ func (n *Node) formCluster(ctx context.Context, requestID string, replicas []pee
return
}

n.log.Info().Str("request", requestID).Str("peer", rp.String()).Msg("accounted consensus cluster response from roll called peer")
n.log.Info().Str("request", requestID).Stringer("peer", rp).Msg("accounted consensus cluster response from roll called peer")

if fc.Code != codes.OK {
log.Warn().Str("peer", rp.String()).Msg("peer failed to join consensus cluster")
log.Warn().Stringer("peer", rp).Msg("peer failed to join consensus cluster")
return
}

Expand Down
8 changes: 4 additions & 4 deletions node/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ import (
)

func (n *Node) processHealthCheck(ctx context.Context, from peer.ID, _ response.Health) error {
n.log.Trace().Str("from", from.String()).Msg("peer health check received")
n.log.Trace().Stringer("peer", from).Msg("peer health check received")
return nil
}

func (n *Node) processRollCallResponse(ctx context.Context, from peer.ID, res response.RollCall) error {

log := n.log.With().Str("request", res.RequestID).Str("peer", from.String()).Logger()
log := n.log.With().Str("request", res.RequestID).Stringer("peer", from).Logger()

log.Debug().Msg("processing peers roll call response")

// Check if the response is adequate.
if res.Code != codes.Accepted {
log.Info().Str("code", res.Code.String()).Msg("skipping inadequate roll call response - unwanted code")
log.Info().Stringer("code", res.Code).Msg("skipping inadequate roll call response - unwanted code")
return nil
}

Expand All @@ -47,6 +47,6 @@ func (n *Node) processRollCallResponse(ctx context.Context, from peer.ID, res re
}

func (n *Node) processInstallFunctionResponse(ctx context.Context, from peer.ID, res response.InstallFunction) error {
n.log.Trace().Str("from", from.String()).Str("cid", res.CID).Msg("function install response received")
n.log.Trace().Stringer("peer", from).Str("cid", res.CID).Msg("function install response received")
return nil
}
8 changes: 0 additions & 8 deletions node/handlers_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,6 @@ func TestNode_InstallFunction(t *testing.T) {
CID: cid,
}

t.Run("head node handles install", func(t *testing.T) {
t.Parallel()

node := createNode(t, blockless.HeadNode)

err := node.processInstallFunction(context.Background(), mocks.GenericPeerID, installReq)
require.NoError(t, err)
})
t.Run("worker node handles install", func(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 0 additions & 7 deletions node/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,12 @@ import (

"github.com/libp2p/go-libp2p/core/peer"

"github.com/blocklessnetwork/b7s/models/blockless"
"github.com/blocklessnetwork/b7s/models/codes"
"github.com/blocklessnetwork/b7s/models/request"
)

func (n *Node) processInstallFunction(ctx context.Context, from peer.ID, req request.InstallFunction) error {

// Only workers should respond to function install requests.
if n.cfg.Role != blockless.WorkerNode {
n.log.Debug().Msg("received function install request, ignoring")
return nil
}

// Install function.
err := n.installFunction(ctx, req.CID, req.ManifestURL)
if err != nil {
Expand Down
14 changes: 5 additions & 9 deletions node/pipeline.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package node

import (
"errors"

"github.com/blocklessnetwork/b7s/models/blockless"
pp "github.com/blocklessnetwork/b7s/node/internal/pipeline"
)

var errDisallowedMessage = errors.New("disallowed message")

func allowedMessage(msg string, pipeline pp.Pipeline) error {
func messageAllowedOnPipeline(msg string, pipeline pp.Pipeline) bool {

if pipeline.ID == pp.DirectMessage {

Expand All @@ -22,10 +18,10 @@ func allowedMessage(msg string, pipeline pp.Pipeline) error {
// Technically we only publish InstallFunction. However, it's handy for tests to support
// direct install, and it's somewhat of a low risk.

return errDisallowedMessage
return false

default:
return nil
return true
}
}

Expand All @@ -40,9 +36,9 @@ func allowedMessage(msg string, pipeline pp.Pipeline) error {
blockless.MessageDisbandCluster,
blockless.MessageRollCallResponse:

return errDisallowedMessage
return false

default:
return nil
return true
}
}
4 changes: 2 additions & 2 deletions node/pipeline_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestNode_DisallowedMessages(t *testing.T) {
}

for _, test := range tests {
err := allowedMessage(test.message, test.pipeline)
require.ErrorIsf(t, err, errDisallowedMessage, "message: %s, pipeline: %s", test.message, test.pipeline)
ok := messageAllowedOnPipeline(test.message, test.pipeline)
require.False(t, ok, "message: %s, pipeline: %s", test.message, test.pipeline)
}
}
Loading

0 comments on commit 1f86fc1

Please sign in to comment.