diff --git a/api/execute.go b/api/execute.go index 5cbf563..d4beca5 100644 --- a/api/execute.go +++ b/api/execute.go @@ -29,6 +29,11 @@ func (a *API) ExecuteFunction(ctx echo.Context) error { Parameters: req.Parameters, } + err = exr.Valid() + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid request: %w", err)) + } + // Get the execution result. code, id, results, cluster, err := a.Node.ExecuteFunction(ctx.Request().Context(), exr, req.Topic) if err != nil { diff --git a/api/install.go b/api/install.go index 825c369..3068019 100644 --- a/api/install.go +++ b/api/install.go @@ -15,6 +15,15 @@ const ( functionInstallTimeout = 10 * time.Second ) +func (r FunctionInstallRequest) Valid() error { + + if r.Cid == "" { + return errors.New("function CID is required") + } + + return nil +} + func (a *API) InstallFunction(ctx echo.Context) error { // Unpack the API request. @@ -24,8 +33,9 @@ func (a *API) InstallFunction(ctx echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("could not unpack request: %w", err)) } - if req.Uri == "" && req.Cid == "" { - return echo.NewHTTPError(http.StatusBadRequest, errors.New("URI or CID are required")) + err = req.Valid() + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid request: %w", err)) } // Add a deadline to the context. diff --git a/api/result.go b/api/result.go index 9c74181..389d40f 100644 --- a/api/result.go +++ b/api/result.go @@ -8,6 +8,15 @@ import ( "github.com/labstack/echo/v4" ) +func (r FunctionResultRequest) Valid() error { + + if r.Id == "" { + return errors.New("request ID is required") + } + + return nil +} + // ExecutionResult implements the REST API endpoint for retrieving the result of a function execution. func (a *API) ExecutionResult(ctx echo.Context) error { @@ -18,13 +27,13 @@ func (a *API) ExecutionResult(ctx echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("could not unpack request: %w", err)) } - requestID := request.Id - if requestID == "" { + err = request.Valid() + if err != nil { return echo.NewHTTPError(http.StatusBadRequest, errors.New("missing request ID")) } // Lookup execution result. - result, ok := a.Node.ExecutionResult(requestID) + result, ok := a.Node.ExecutionResult(request.Id) if !ok { return ctx.NoContent(http.StatusNotFound) } diff --git a/models/execute/request.go b/models/execute/request.go index cc5ef19..dd53721 100644 --- a/models/execute/request.go +++ b/models/execute/request.go @@ -1,5 +1,11 @@ package execute +import ( + "errors" + + "github.com/hashicorp/go-multierror" +) + // Request describes an execution request. type Request struct { FunctionID string `json:"function_id"` @@ -11,6 +17,21 @@ type Request struct { Signature string `json:"signature,omitempty"` } +func (r Request) Valid() error { + + var err *multierror.Error + + if r.FunctionID == "" { + err = multierror.Append(err, errors.New("function ID is required")) + } + + if r.Method == "" { + err = multierror.Append(err, errors.New("method is required")) + } + + return err.ErrorOrNil() +} + // Parameter represents an execution parameter, modeled as a key-value pair. type Parameter struct { Name string `json:"name,omitempty"` diff --git a/models/request/execute.go b/models/request/execute.go index 64be9cb..13ad027 100644 --- a/models/request/execute.go +++ b/models/request/execute.go @@ -11,6 +11,7 @@ import ( "github.com/blocklessnetwork/b7s/models/codes" "github.com/blocklessnetwork/b7s/models/execute" "github.com/blocklessnetwork/b7s/models/response" + "github.com/hashicorp/go-multierror" ) var _ (json.Marshaler) = (*Execute)(nil) @@ -50,16 +51,23 @@ func (e Execute) MarshalJSON() ([]byte, error) { func (e Execute) Valid() error { + var multierr *multierror.Error + err := e.Request.Valid() + if err != nil { + multierr = multierror.Append(multierr, err) + } + c, err := consensus.Parse(e.Config.ConsensusAlgorithm) if err != nil { - return fmt.Errorf("could not parse consensus algorithm: %w", err) + multierr = multierror.Append(multierr, fmt.Errorf("could not parse consensus algorithm: %w", err)) } if c == consensus.PBFT && e.Config.NodeCount > 0 && e.Config.NodeCount < pbft.MinimumReplicaCount { - return fmt.Errorf("minimum %v nodes needed for PBFT consensus", pbft.MinimumReplicaCount) + + multierr = multierror.Append(multierr, fmt.Errorf("minimum %v nodes needed for PBFT consensus", pbft.MinimumReplicaCount)) } - return nil + return multierr.ErrorOrNil() } diff --git a/models/request/install_function.go b/models/request/install_function.go index 90d7de4..7abd5ea 100644 --- a/models/request/install_function.go +++ b/models/request/install_function.go @@ -2,6 +2,7 @@ package request import ( "encoding/json" + "errors" "github.com/blocklessnetwork/b7s/models/blockless" "github.com/blocklessnetwork/b7s/models/codes" @@ -39,3 +40,12 @@ func (f InstallFunction) MarshalJSON() ([]byte, error) { } return json.Marshal(rec) } + +func (f InstallFunction) Valid() error { + + if f.CID == "" { + return errors.New("function CID is required") + } + + return nil +} diff --git a/node/cluster.go b/node/cluster.go index e0f58f0..2159ceb 100644 --- a/node/cluster.go +++ b/node/cluster.go @@ -18,13 +18,7 @@ import ( func (n *Node) processFormCluster(ctx context.Context, from peer.ID, req request.FormCluster) error { - // Should never happen. - if !n.isWorker() { - n.log.Warn().Str("peer", from.String()).Msg("only worker nodes participate in consensus clusters") - return nil - } - - n.log.Info().Str("request", req.RequestID).Strs("peers", blockless.PeerIDsToStr(req.Peers)).Str("consensus", req.Consensus.String()).Msg("received request to form consensus cluster") + n.log.Info().Str("request", req.RequestID).Strs("peers", blockless.PeerIDsToStr(req.Peers)).Stringer("consensus", req.Consensus).Msg("received request to form consensus cluster") // Add connection info about peers if we're not already connected to them. for _, addrInfo := range req.ConnectionInfo { @@ -60,7 +54,7 @@ func (n *Node) processFormCluster(ctx context.Context, from peer.ID, req request // processFormClusterResponse will record the cluster formation response. func (n *Node) processFormClusterResponse(ctx context.Context, from peer.ID, res response.FormCluster) error { - n.log.Debug().Str("request", res.RequestID).Str("from", from.String()).Msg("received cluster formation response") + n.log.Debug().Str("request", res.RequestID).Stringer("from", from).Msg("received cluster formation response") key := consensusResponseKey(res.RequestID, from) n.consensusResponses.Set(key, res) @@ -71,20 +65,14 @@ func (n *Node) processFormClusterResponse(ctx context.Context, from peer.ID, res // processDisbandCluster will start cluster shutdown command. func (n *Node) processDisbandCluster(ctx context.Context, from peer.ID, req request.DisbandCluster) error { - // Should never happen. - if !n.isWorker() { - n.log.Warn().Str("peer", from.String()).Msg("only worker nodes participate in consensus clusters") - return nil - } - - n.log.Info().Str("peer", from.String()).Str("request", req.RequestID).Msg("received request to disband consensus cluster") + n.log.Info().Stringer("peer", from).Str("request", req.RequestID).Msg("received request to disband consensus cluster") err := n.leaveCluster(req.RequestID, consensusClusterDisbandTimeout) if err != nil { return fmt.Errorf("could not disband cluster (request: %s): %w", req.RequestID, err) } - n.log.Info().Str("peer", from.String()).Str("request", req.RequestID).Msg("left consensus cluster") + n.log.Info().Stringer("peer", from).Str("request", req.RequestID).Msg("left consensus cluster") return nil } @@ -140,10 +128,10 @@ func (n *Node) formCluster(ctx context.Context, requestID string, replicas []pee return } - n.log.Info().Str("request", requestID).Str("peer", rp.String()).Msg("accounted consensus cluster response from roll called peer") + n.log.Info().Str("request", requestID).Stringer("peer", rp).Msg("accounted consensus cluster response from roll called peer") if fc.Code != codes.OK { - log.Warn().Str("peer", rp.String()).Msg("peer failed to join consensus cluster") + log.Warn().Stringer("peer", rp).Msg("peer failed to join consensus cluster") return } diff --git a/node/handlers.go b/node/handlers.go index c56e29f..50d3bce 100644 --- a/node/handlers.go +++ b/node/handlers.go @@ -10,19 +10,19 @@ import ( ) func (n *Node) processHealthCheck(ctx context.Context, from peer.ID, _ response.Health) error { - n.log.Trace().Str("from", from.String()).Msg("peer health check received") + n.log.Trace().Stringer("peer", from).Msg("peer health check received") return nil } func (n *Node) processRollCallResponse(ctx context.Context, from peer.ID, res response.RollCall) error { - log := n.log.With().Str("request", res.RequestID).Str("peer", from.String()).Logger() + log := n.log.With().Str("request", res.RequestID).Stringer("peer", from).Logger() log.Debug().Msg("processing peers roll call response") // Check if the response is adequate. if res.Code != codes.Accepted { - log.Info().Str("code", res.Code.String()).Msg("skipping inadequate roll call response - unwanted code") + log.Info().Stringer("code", res.Code).Msg("skipping inadequate roll call response - unwanted code") return nil } @@ -47,6 +47,6 @@ func (n *Node) processRollCallResponse(ctx context.Context, from peer.ID, res re } func (n *Node) processInstallFunctionResponse(ctx context.Context, from peer.ID, res response.InstallFunction) error { - n.log.Trace().Str("from", from.String()).Str("cid", res.CID).Msg("function install response received") + n.log.Trace().Stringer("peer", from).Str("cid", res.CID).Msg("function install response received") return nil } diff --git a/node/handlers_internal_test.go b/node/handlers_internal_test.go index 01619ba..07f043a 100644 --- a/node/handlers_internal_test.go +++ b/node/handlers_internal_test.go @@ -120,14 +120,6 @@ func TestNode_InstallFunction(t *testing.T) { CID: cid, } - t.Run("head node handles install", func(t *testing.T) { - t.Parallel() - - node := createNode(t, blockless.HeadNode) - - err := node.processInstallFunction(context.Background(), mocks.GenericPeerID, installReq) - require.NoError(t, err) - }) t.Run("worker node handles install", func(t *testing.T) { t.Parallel() diff --git a/node/install.go b/node/install.go index da71566..0c2907f 100644 --- a/node/install.go +++ b/node/install.go @@ -6,19 +6,12 @@ import ( "github.com/libp2p/go-libp2p/core/peer" - "github.com/blocklessnetwork/b7s/models/blockless" "github.com/blocklessnetwork/b7s/models/codes" "github.com/blocklessnetwork/b7s/models/request" ) func (n *Node) processInstallFunction(ctx context.Context, from peer.ID, req request.InstallFunction) error { - // Only workers should respond to function install requests. - if n.cfg.Role != blockless.WorkerNode { - n.log.Debug().Msg("received function install request, ignoring") - return nil - } - // Install function. err := n.installFunction(ctx, req.CID, req.ManifestURL) if err != nil { diff --git a/node/pipeline.go b/node/pipeline.go index e8bc4c6..d2de642 100644 --- a/node/pipeline.go +++ b/node/pipeline.go @@ -1,15 +1,11 @@ package node import ( - "errors" - "github.com/blocklessnetwork/b7s/models/blockless" pp "github.com/blocklessnetwork/b7s/node/internal/pipeline" ) -var errDisallowedMessage = errors.New("disallowed message") - -func allowedMessage(msg string, pipeline pp.Pipeline) error { +func messageAllowedOnPipeline(msg string, pipeline pp.Pipeline) bool { if pipeline.ID == pp.DirectMessage { @@ -22,10 +18,10 @@ func allowedMessage(msg string, pipeline pp.Pipeline) error { // Technically we only publish InstallFunction. However, it's handy for tests to support // direct install, and it's somewhat of a low risk. - return errDisallowedMessage + return false default: - return nil + return true } } @@ -40,9 +36,9 @@ func allowedMessage(msg string, pipeline pp.Pipeline) error { blockless.MessageDisbandCluster, blockless.MessageRollCallResponse: - return errDisallowedMessage + return false default: - return nil + return true } } diff --git a/node/pipeline_internal_test.go b/node/pipeline_internal_test.go index 3ff2b97..232d636 100644 --- a/node/pipeline_internal_test.go +++ b/node/pipeline_internal_test.go @@ -33,7 +33,7 @@ func TestNode_DisallowedMessages(t *testing.T) { } for _, test := range tests { - err := allowedMessage(test.message, test.pipeline) - require.ErrorIsf(t, err, errDisallowedMessage, "message: %s, pipeline: %s", test.message, test.pipeline) + ok := messageAllowedOnPipeline(test.message, test.pipeline) + require.False(t, ok, "message: %s, pipeline: %s", test.message, test.pipeline) } } diff --git a/node/process.go b/node/process.go index 95cc1dd..e31fc86 100644 --- a/node/process.go +++ b/node/process.go @@ -23,6 +23,18 @@ func (n *Node) processMessage(ctx context.Context, from peer.ID, payload []byte, return fmt.Errorf("could not unpack message: %w", err) } + log := n.log.With().Stringer("peer", from).Str("type", msgType).Stringer("pipeline", pipeline).Logger() + + if !messageAllowedOnPipeline(msgType, pipeline) { + log.Debug().Msg("message not allowed on pipeline") + return nil + } + + if !n.messageAllowedForRole(msgType) { + log.Debug().Msg("message not intended for our role") + return nil + } + n.metrics.IncrCounterWithLabels(messagesProcessedMetric, 1, []metrics.Label{{Name: "type", Value: msgType}}) defer func() { switch procError { @@ -55,14 +67,6 @@ func (n *Node) processMessage(ctx context.Context, from peer.ID, payload []byte, span.SetStatus(otelcodes.Error, spanStatusErr) }() - log := n.log.With().Str("peer", from.String()).Str("type", msgType).Str("pipeline", pipeline.String()).Logger() - - err = allowedMessage(msgType, pipeline) - if err != nil { - log.Warn().Msg("message not allowed on pipeline") - return nil - } - log.Debug().Msg("received message from peer") switch msgType { @@ -96,6 +100,42 @@ func (n *Node) processMessage(ctx context.Context, from peer.ID, payload []byte, } } +func (n *Node) messageAllowedForRole(msgType string) bool { + + // Worker node allowed messages. + if n.isWorker() { + switch msgType { + case blockless.MessageHealthCheck, + blockless.MessageInstallFunction, + blockless.MessageRollCall, + blockless.MessageExecute, + blockless.MessageFormCluster, + blockless.MessageDisbandCluster: + return true + + default: + return false + } + } + + // Head node allowed messages. + switch msgType { + + case blockless.MessageHealthCheck, + blockless.MessageInstallFunctionResponse, + blockless.MessageRollCallResponse, + blockless.MessageExecute, + blockless.MessageExecuteResponse, + blockless.MessageFormClusterResponse: + + // NOTE: We provide a mechanism via the REST API to broadcast function install, so there's a case for this being supported. + return true + + default: + return false + } +} + func handleMessage[T blockless.Message](ctx context.Context, from peer.ID, payload []byte, processFunc func(ctx context.Context, from peer.ID, msg T) error) error { var msg T @@ -104,6 +144,19 @@ func handleMessage[T blockless.Message](ctx context.Context, from peer.ID, paylo return fmt.Errorf("could not unmarshal message: %w", err) } + // If the message provides a validation mechanism - use it. + type validator interface { + Valid() error + } + + vmsg, ok := any(msg).(validator) + if ok { + err = vmsg.Valid() + if err != nil { + return fmt.Errorf("rejecting message that failed validation: %w", err) + } + } + return processFunc(ctx, from, msg) } diff --git a/node/roll_call.go b/node/roll_call.go index 40b6b43..eb5c192 100644 --- a/node/roll_call.go +++ b/node/roll_call.go @@ -18,12 +18,6 @@ import ( func (n *Node) processRollCall(ctx context.Context, from peer.ID, req request.RollCall) error { - // Only workers respond to roll calls at the moment. - if n.cfg.Role != blockless.WorkerNode { - n.log.Debug().Msg("skipping roll call as a non-worker node") - return nil - } - n.metrics.IncrCounterWithLabels(rollCallsSeenMetric, 1, []metrics.Label{{Name: "function", Value: req.FunctionID}}) log := n.log.With().Str("request", req.RequestID).Str("origin", req.Origin.String()).Str("function", req.FunctionID).Logger() diff --git a/node/roll_call_internal_test.go b/node/roll_call_internal_test.go index b55105d..10578da 100644 --- a/node/roll_call_internal_test.go +++ b/node/roll_call_internal_test.go @@ -21,19 +21,6 @@ import ( func TestNode_RollCall(t *testing.T) { - t.Run("head node handles roll call", func(t *testing.T) { - t.Parallel() - - rollCallReq := request.RollCall{ - FunctionID: "dummy-function-id", - RequestID: mocks.GenericUUID.String(), - } - - node := createNode(t, blockless.HeadNode) - err := node.processRollCall(context.Background(), mocks.GenericPeerID, rollCallReq) - require.NoError(t, err) - }) - t.Run("worker node handles roll call", func(t *testing.T) { t.Parallel() diff --git a/node/telemetry_message_internal_test.go b/node/telemetry_message_internal_test.go index bbf9e6b..2fcd2e3 100644 --- a/node/telemetry_message_internal_test.go +++ b/node/telemetry_message_internal_test.go @@ -273,55 +273,47 @@ func TestNode_ProcessedMessageMetric(t *testing.T) { node.metrics = m - // Messages to send. We will send multiple health check, execution response and install response messages. + // Messages to send. We will send multiple health check and disband cluster messages. // Note that not all messages make sense in the context of a real-world node, but we just care about having // a few messages flow through the system. var ( // Do between 1 and 10 messages. limit = 10 healthcheckCount = rand.Intn(limit) + 1 - execCount = rand.Intn(limit) + 1 - installCount = rand.Intn(limit) + 1 + disbandCount = rand.Intn(limit) + 1 healthCheck = response.Health{} - execResponse = response.Execute{ + disbandRequest = request.DisbandCluster{ RequestID: newRequestID(), - Results: execute.ResultMap{mocks.GenericPeerID: execute.NodeResult{Result: mocks.GenericExecutionResult}}, - } - - instResponse = response.InstallFunction{ - CID: mocks.GenericFunctionRecord.CID, } ) msgs := []struct { - count int - payload []byte + count int + pipeline pipeline.Pipeline + payload []byte }{ { - count: healthcheckCount, - payload: serialize(t, healthCheck), + count: healthcheckCount, + pipeline: pipeline.PubSubPipeline(DefaultTopic), + payload: serialize(t, healthCheck), }, { - count: execCount, - payload: serialize(t, execResponse), - }, - { - count: installCount, - payload: serialize(t, instResponse), + count: disbandCount, + pipeline: pipeline.DirectMessagePipeline(), + payload: serialize(t, disbandRequest), }, } for _, msg := range msgs { for i := 0; i < msg.count; i++ { - err = node.processMessage(ctx, mocks.GenericPeerID, msg.payload, pipeline.PubSubPipeline(DefaultTopic)) - require.NoError(t, err) + // We don't care if the message was processed okay (disband cluster will fail). + _ = node.processMessage(ctx, mocks.GenericPeerID, msg.payload, msg.pipeline) } } metricMap := helpers.MetricMap(t, registry) helpers.CounterCmp(t, metricMap, float64(healthcheckCount), "b7s_node_messages_processed", "type", "MsgHealthCheck") - helpers.CounterCmp(t, metricMap, float64(execCount), "b7s_node_messages_processed", "type", "MsgExecuteResponse") - helpers.CounterCmp(t, metricMap, float64(installCount), "b7s_node_messages_processed", "type", "MsgInstallFunctionResponse") + helpers.CounterCmp(t, metricMap, float64(disbandCount), "b7s_node_messages_processed", "type", "MsgDisbandCluster") } diff --git a/telemetry/tracing/propagation.go b/telemetry/tracing/propagation.go index b26f213..0c8237b 100644 --- a/telemetry/tracing/propagation.go +++ b/telemetry/tracing/propagation.go @@ -10,7 +10,7 @@ import ( ) type TraceInfo struct { - Carrier propagation.MapCarrier + Carrier propagation.MapCarrier `json:"carrier,omitempty"` } // Empty returns true if the TraceInfo structure contains any tracing information.