Implement validation for buffered partial messages (#829)
When a message arrives with an EC chain that has yet to be discovered,
we want to do as much validation as possible before buffering it for
future use.

Because the key of the EC chain is included in partial messages, we
can validate essentially everything about the message except the chain
itself. The changes here implement this ability.
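
As a rough illustration of the shapes involved (only PartialGMessage,
PartiallyValidatedMessage and the embedded gpbft.GMessage appear in
the diff below; the VoteValueKey field name is an assumption), a
partial message carries the vote's chain key in place of the chain
itself, and a partially validated message wraps it once the
chain-independent checks have passed:

```go
// Hypothetical sketch, not the actual go-f3 definitions.
type PartialGMessage struct {
	*gpbft.GMessage
	// VoteValueKey identifies the vote's EC chain via chainexchange,
	// allowing the chain payload to be omitted from the wire message.
	VoteValueKey chainexchange.Key
}

// PartiallyValidatedMessage wraps a partial message whose signature,
// sender, instance, round and phase have already been checked; only the
// chain-dependent checks remain once the chain is discovered.
type PartiallyValidatedMessage struct {
	*PartialGMessage
}
```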

The implementation introduces a new validator written specifically to
handle partial messages. Its validation rules are similar to those of
the full GMessage validator, but adapted to infer the state of a
message from the chain key instead. A separate issue is captured to
reduce the duplication of rules across the two validators. That
refactor is postponed on purpose, as it touches a lot of the core
implementation. Instead, the work here aims to provide a correct
partial validation mechanism so that these efforts can progress.

The new validator makes a final validation attempt once the chain is
discovered, checking the validity of the chain, its consistency with
the key, and the justification.
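
A minimal sketch of the two-stage flow, assuming the validator methods
visible in the diff (PartiallyValidateMessage and ValidateMessage); the
wrapper functions and error handling here are illustrative only:

```go
// Stage 1: on pubsub delivery, validate everything that does not require
// the full chain, then buffer the message until chainexchange discovers
// the chain for its key.
func onPartialMessage(ctx context.Context, r *gpbftRunner, pgmsg *PartialGMessage) pubsub.ValidationResult {
	pvmsg, err := r.pmv.PartiallyValidateMessage(pgmsg)
	result := pubsubValidationResultFromError(err)
	if result == pubsub.ValidationAccept {
		r.pmm.bufferPartialMessage(ctx, pvmsg)
	}
	return result
}

// Stage 2: once the chain is discovered and the message is completed,
// check the chain itself, its consistency with the key, and the
// justification, then hand the fully validated message to the participant.
func onCompletedMessage(r *gpbftRunner, pvmsg *PartiallyValidatedMessage) error {
	validated, err := r.pmv.ValidateMessage(pvmsg)
	if err != nil {
		return err
	}
	return r.participant.ReceiveMessage(validated)
}
```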

The partial validator maintains its own cache, similar to the full
validator, from which messages are evicted as soon as GPBFT
progresses. The caches of the full and partial validators are
independent and do not overlap: a message that can be completed
immediately is never cached by the partial validator, and vice versa.
This should keep the total memory footprint across both validators
roughly the same as before.
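
The eviction hook mirrors the ReceiveDecision change in host.go: cache
entries are grouped by instance, and once a decision is reached all
groups older than the previous instance are dropped. A simplified
sketch (the method name is illustrative):

```go
// Evict partial-validation cache groups that can no longer be relevant
// once GPBFT has decided the given instance.
func (h *gpbftHost) evictPartialValidationCache(decision *gpbft.Justification) {
	if decision.Vote.Instance > 0 {
		h.pmCache.RemoveGroupsLessThan(decision.Vote.Instance - 1)
	}
}
```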

Fixes #813
masih authored Jan 16, 2025
1 parent 935bafa commit 79a7008
Showing 4 changed files with 452 additions and 43 deletions.
62 changes: 32 additions & 30 deletions host.go
@@ -12,6 +12,7 @@ import (
"github.com/filecoin-project/go-f3/certstore"
"github.com/filecoin-project/go-f3/ec"
"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/caching"
"github.com/filecoin-project/go-f3/internal/clock"
"github.com/filecoin-project/go-f3/internal/psutil"
"github.com/filecoin-project/go-f3/internal/writeaheadlog"
@@ -53,6 +54,8 @@ type gpbftRunner struct {
inputs gpbftInputs
msgEncoding gMessageEncoding
pmm *partialMessageManager
pmv *cachingPartialValidator
pmCache *caching.GroupedSet
}

type roundPhase struct {
@@ -148,6 +151,10 @@ func newRunner(
return nil, fmt.Errorf("creating partial message manager: %w", err)
}

runner.pmCache = caching.NewGroupedSet(int(m.CommitteeLookback), 25_000)
obfuscatedHost := (*gpbftHost)(runner)
runner.pmv = newCachingPartialValidator(obfuscatedHost, runner.Progress, runner.pmCache, m.CommitteeLookback, runner.pmm.chainex)

return runner, nil
}

@@ -231,23 +238,13 @@ func (h *gpbftRunner) Start(ctx context.Context) (_err error) {
// errors.
log.Errorf("error when processing message: %+v", err)
}
case gmsg, ok := <-completedMessageQueue:
case pvmsg, ok := <-completedMessageQueue:
if !ok {
return fmt.Errorf("incoming completed message queue closed")
}
switch validatedMessage, err := h.participant.ValidateMessage(gmsg); {
case errors.Is(err, gpbft.ErrValidationInvalid):
log.Debugw("validation error while validating completed message", "err", err)
// TODO: Signal partial message manager to penalise sender,
// e.g. reduce the total number of messages stroed from sender?
case errors.Is(err, gpbft.ErrValidationTooOld):
// TODO: Signal partial message manager to drop the instance?
case errors.Is(err, gpbft.ErrValidationNotRelevant):
// TODO: Signal partial message manager to drop irrelevant messages?
case errors.Is(err, gpbft.ErrValidationNoCommittee):
log.Debugw("committee error while validating completed message", "err", err)
switch validatedMessage, err := h.pmv.ValidateMessage(pvmsg); {
case err != nil:
log.Errorw("unknown error while validating completed message", "err", err)
log.Debugw("Invalid partially validated message", "err", err)
default:
recordValidatedMessage(ctx, validatedMessage)
if err := h.participant.ReceiveMessage(validatedMessage); err != nil {
@@ -552,29 +549,33 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg

gmsg, completed := h.pmm.CompleteMessage(ctx, pgmsg)
if !completed {
// TODO: Partially validate the message because we can. To do this, however,
// message validator needs to be refactored to tolerate partial data.
// Hence, for now validation is postponed entirely until that refactor
// is done to accommodate partial messages.
// See: https://github.com/filecoin-project/go-f3/issues/813

// FIXME: must verify signature before buffering otherwise nodes can spoof the
// buffer with invalid messages on behalf of other peers as censorship
// attack.
partiallyValidatedMessage, err := h.pmv.PartiallyValidateMessage(pgmsg)
result := pubsubValidationResultFromError(err)
if result == pubsub.ValidationAccept {
msg.ValidatorData = partiallyValidatedMessage
}
return result
}

msg.ValidatorData = pgmsg
return pubsub.ValidationAccept
validatedMessage, err := h.participant.ValidateMessage(gmsg)
result := pubsubValidationResultFromError(err)
if result == pubsub.ValidationAccept {
recordValidatedMessage(ctx, validatedMessage)
msg.ValidatorData = validatedMessage
}
return result
}

switch validatedMessage, err := h.participant.ValidateMessage(gmsg); {
func pubsubValidationResultFromError(err error) pubsub.ValidationResult {
switch {
case errors.Is(err, gpbft.ErrValidationInvalid):
log.Debugf("validation error during validation: %+v", err)
return pubsub.ValidationReject
case errors.Is(err, gpbft.ErrValidationTooOld):
// we got the message too late
// The message has arrived too late to be useful. Ignore it.
return pubsub.ValidationIgnore
case errors.Is(err, gpbft.ErrValidationNotRelevant):
// The message is valid but will not effectively aid progress of GPBFT. Ignore it
// The message is valid but won't effectively aid the progress of GPBFT. Ignore it
// to stop its further propagation across the network.
return pubsub.ValidationIgnore
case errors.Is(err, gpbft.ErrValidationNoCommittee):
@@ -584,8 +585,6 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg
log.Infof("unknown error during validation: %+v", err)
return pubsub.ValidationIgnore
default:
recordValidatedMessage(ctx, validatedMessage)
msg.ValidatorData = validatedMessage
return pubsub.ValidationAccept
}
}
@@ -667,7 +666,7 @@ func (h *gpbftRunner) startPubsub() (<-chan gpbft.ValidatedMessage, error) {
case <-h.runningCtx.Done():
return nil
}
case *PartialGMessage:
case *PartiallyValidatedMessage:
h.pmm.bufferPartialMessage(h.runningCtx, gmsg)
default:
log.Errorf("invalid msgValidatorData: %+v", msg.ValidatorData)
@@ -787,6 +786,9 @@ func (h *gpbftHost) SetAlarm(at time.Time) {
func (h *gpbftHost) ReceiveDecision(decision *gpbft.Justification) (time.Time, error) {
log.Infow("reached a decision", "instance", decision.Vote.Instance,
"ecHeadEpoch", decision.Vote.Value.Head().Epoch)
if decision.Vote.Instance > 0 {
h.pmCache.RemoveGroupsLessThan(decision.Vote.Instance - 1)
}
cert, err := h.saveDecision(decision)
if err != nil {
err := fmt.Errorf("error while saving decision: %+v", err)
5 changes: 4 additions & 1 deletion merkle/merkle.go
@@ -8,8 +8,11 @@ import (
"golang.org/x/crypto/sha3"
)

// DigestLength is the length of a Digest in number of bytes.
const DigestLength = 32

// Digest is a 32-byte hash digest.
type Digest = [32]byte
type Digest = [DigestLength]byte

// TreeWithProofs returns a the root of the merkle-tree of the given values, along with merkle-proofs for
// each leaf.
24 changes: 12 additions & 12 deletions partial_msg.go
@@ -34,12 +34,12 @@ type partialMessageManager struct {

// pmByInstance is a map of instance to a buffer of partial messages that are
// keyed by sender+instance+round+phase.
pmByInstance map[uint64]*lru.Cache[partialMessageKey, *PartialGMessage]
pmByInstance map[uint64]*lru.Cache[partialMessageKey, *PartiallyValidatedMessage]
// pmkByInstanceByChainKey is used for an auxiliary lookup of all partial
// messages for a given vote value at an instance.
pmkByInstanceByChainKey map[uint64]map[string][]partialMessageKey
// pendingPartialMessages is a channel of partial messages that are pending to be buffered.
pendingPartialMessages chan *PartialGMessage
pendingPartialMessages chan *PartiallyValidatedMessage
// pendingDiscoveredChains is a channel of chains discovered by chainexchange
// that are pending to be processed.
pendingDiscoveredChains chan *discoveredChain
@@ -49,10 +49,10 @@

func newPartialMessageManager(progress gpbft.Progress, ps *pubsub.PubSub, m *manifest.Manifest) (*partialMessageManager, error) {
pmm := &partialMessageManager{
pmByInstance: make(map[uint64]*lru.Cache[partialMessageKey, *PartialGMessage]),
pmByInstance: make(map[uint64]*lru.Cache[partialMessageKey, *PartiallyValidatedMessage]),
pmkByInstanceByChainKey: make(map[uint64]map[string][]partialMessageKey),
pendingDiscoveredChains: make(chan *discoveredChain, 100), // TODO: parameterize buffer size.
pendingPartialMessages: make(chan *PartialGMessage, 100), // TODO: parameterize buffer size.
pendingDiscoveredChains: make(chan *discoveredChain, 100), // TODO: parameterize buffer size.
pendingPartialMessages: make(chan *PartiallyValidatedMessage, 100), // TODO: parameterize buffer size.
}
var err error
pmm.chainex, err = chainexchange.NewPubSubChainExchange(
@@ -72,12 +72,12 @@ func newPartialMessageManager(progress gpbft.Progress, ps *pubsub.PubSub, m *man
return pmm, nil
}

func (pmm *partialMessageManager) Start(ctx context.Context) (<-chan *gpbft.GMessage, error) {
func (pmm *partialMessageManager) Start(ctx context.Context) (<-chan *PartiallyValidatedMessage, error) {
if err := pmm.chainex.Start(ctx); err != nil {
return nil, fmt.Errorf("starting chain exchange: %w", err)
}

completedMessages := make(chan *gpbft.GMessage, 100) // TODO: parameterize buffer size.
completedMessages := make(chan *PartiallyValidatedMessage, 100) // TODO: parameterize buffer size.
ctx, pmm.stop = context.WithCancel(context.Background())
go func() {
defer func() {
@@ -111,11 +111,11 @@ func (pmm *partialMessageManager) Start(ctx context.Context) (<-chan *gpbft.GMes
for _, messageKey := range partialMessageKeys {
if pgmsg, found := buffer.Get(messageKey); found {
pgmsg.Vote.Value = discovered.chain
inferJustificationVoteValue(pgmsg)
inferJustificationVoteValue(pgmsg.PartialGMessage)
select {
case <-ctx.Done():
return
case completedMessages <- pgmsg.GMessage:
case completedMessages <- pgmsg:
default:
log.Warnw("Dropped completed message as the gpbft runner is too slow to consume them.", "msg", pgmsg.GMessage)
}
@@ -214,7 +214,7 @@ func (pmm *partialMessageManager) NotifyChainDiscovered(ctx context.Context, key
}
}

func (pmm *partialMessageManager) bufferPartialMessage(ctx context.Context, msg *PartialGMessage) {
func (pmm *partialMessageManager) bufferPartialMessage(ctx context.Context, msg *PartiallyValidatedMessage) {
select {
case <-ctx.Done():
return
@@ -229,14 +229,14 @@ func (pmm *partialMessageManager) bufferPartialMessage(ctx context.Context, msg
}
}

func (pmm *partialMessageManager) getOrInitPartialMessageBuffer(instance uint64) *lru.Cache[partialMessageKey, *PartialGMessage] {
func (pmm *partialMessageManager) getOrInitPartialMessageBuffer(instance uint64) *lru.Cache[partialMessageKey, *PartiallyValidatedMessage] {
buffer, found := pmm.pmByInstance[instance]
if !found {
// TODO: parameterize this in the manifest?
// Based on 5 phases, 2K network size at a couple of rounds plus some headroom.
const maxBufferedMessagesPerInstance = 25_000
var err error
buffer, err = lru.New[partialMessageKey, *PartialGMessage](maxBufferedMessagesPerInstance)
buffer, err = lru.New[partialMessageKey, *PartiallyValidatedMessage](maxBufferedMessagesPerInstance)
if err != nil {
log.Fatalf("Failed to create buffer for instance %d: %s", instance, err)
panic(err)
