diff --git a/Makefile b/Makefile index 381f9b57..fecb0f0c 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ all: build .PHONY: all test: - $(GOTESTSUM) -- -timeout 5m -p 1 ./... + $(GOTESTSUM) -- -race -timeout 5m -p 1 ./... .PHONY: test install: install-buf install-protoc diff --git a/gateway/challenge_verifier/verifier.go b/gateway/challenge_verifier/verifier.go index e10ec5b5..2f202201 100644 --- a/gateway/challenge_verifier/verifier.go +++ b/gateway/challenge_verifier/verifier.go @@ -83,7 +83,7 @@ func (a *caching) Verify(ctx context.Context, challenge, signature []byte) (*Res logger := logging.FromContext(ctx).WithFields(log.Field(zap.Binary("challenge", challengeHash[:]))) if result, ok := a.cache.Get(challengeHash); ok { logger.Debug("retrieved challenge verifier result from the cache") - // SAFETY: type assertion will never panic as we insert only `*ATX` values. + // SAFETY: type assertion will never panic as we insert only `*challengeVerifierResult` values. result := result.(*challengeVerifierResult) return result.Result, result.err } diff --git a/go.mod b/go.mod index ffb9281a..7018f7f7 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 go.uber.org/zap v1.24.0 + golang.org/x/exp v0.0.0-20221208152030-732eee02a75a golang.org/x/sync v0.1.0 google.golang.org/genproto v0.0.0-20221114212237-e4508ebdbee1 google.golang.org/grpc v1.51.0 diff --git a/go.sum b/go.sum index b953a003..6f86e65d 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,8 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20221208152030-732eee02a75a h1:4iLhBPcpqFmylhnkbY3W0ONLUYYkDAW9xMFLfxgsvCw= +golang.org/x/exp v0.0.0-20221208152030-732eee02a75a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/rpc/rpcserver.go b/rpc/rpcserver.go index ead16097..cbedebe1 100644 --- a/rpc/rpcserver.go +++ b/rpc/rpcserver.go @@ -82,7 +82,7 @@ func (r *rpcServer) Start(ctx context.Context, in *api.StartRequest) (*api.Start verifier, err := service.CreateChallengeVerifier(gtwManager.Connections()) if err != nil { - return nil, fmt.Errorf("failed to create ATX provider: %w", err) + return nil, fmt.Errorf("failed to create challenge verifier: %w", err) } // Swap the new and old gateway managers. @@ -136,7 +136,7 @@ func (r *rpcServer) UpdateGateway(ctx context.Context, in *api.UpdateGatewayRequ verifier, err := service.CreateChallengeVerifier(gtwManager.Connections()) if err != nil { - return nil, fmt.Errorf("failed to create ATX provider: %w", err) + return nil, fmt.Errorf("failed to create challenge verifier: %w", err) } // Swap the new and old gateway managers. 
@@ -163,13 +163,13 @@ func (r *rpcServer) Submit(ctx context.Context, in *api.SubmitRequest) (*api.Sub
 	}
 
 	out := new(api.SubmitResponse)
-	out.RoundId = round.ID
+	out.RoundId = round
 	out.Hash = hash
 	return out, nil
 }
 
 func (r *rpcServer) GetInfo(ctx context.Context, in *api.GetInfoRequest) (*api.GetInfoResponse, error) {
-	info, err := r.s.Info()
+	info, err := r.s.Info(ctx)
 	if err != nil {
 		return nil, err
 	}
diff --git a/service/round.go b/service/round.go
index 31588b24..d921a0ed 100644
--- a/service/round.go
+++ b/service/round.go
@@ -7,7 +7,6 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
-	"sync"
 	"time"
 
 	"github.com/spacemeshos/merkle-tree"
@@ -16,6 +15,7 @@ import (
 	"github.com/syndtr/goleveldb/leveldb/opt"
 
 	"github.com/spacemeshos/poet/hash"
+	"github.com/spacemeshos/poet/logging"
 	"github.com/spacemeshos/poet/prover"
 	"github.com/spacemeshos/poet/shared"
 )
@@ -63,8 +63,6 @@ type round struct {
 	teardownChan chan struct{}
 
 	stateCache *roundState
-
-	submitMtx sync.Mutex
 }
 
 func (r *round) Epoch() uint32 {
@@ -141,8 +139,6 @@ func (r *round) submit(key, challenge []byte) error {
 		return errors.New("round is not open")
 	}
 
-	r.submitMtx.Lock()
-	defer r.submitMtx.Unlock()
 	if has, err := r.challengesDb.Has(key); err != nil {
 		return err
 	} else if has {
@@ -170,6 +166,9 @@ func (r *round) isEmpty() bool {
 }
 
 func (r *round) execute(ctx context.Context, end time.Time, minMemoryLayer uint) error {
+	logger := logging.FromContext(ctx).WithFields(log.String("round", r.ID))
+	logger.Info("executing until %v...", end)
+
 	r.executionStarted = time.Now()
 	if err := r.saveState(); err != nil {
 		return err
@@ -177,13 +176,11 @@ func (r *round) execute(ctx context.Context, end time.Time, minMemoryLayer uint)
 
 	close(r.executionStartedChan)
 
-	r.submitMtx.Lock()
 	var err error
 	r.execution.Members, r.execution.Statement, err = r.calcMembersAndStatement()
 	if err != nil {
 		return err
 	}
-	r.submitMtx.Unlock()
 
 	if err := r.saveState(); err != nil {
 		return err
@@ -208,6 +205,7 @@ func (r *round) execute(ctx context.Context, end time.Time, minMemoryLayer uint)
 
 	close(r.executionEndedChan)
 
+	logger.Info("execution ended, phi=%x, duration %v", r.execution.NIP.Root, time.Since(r.executionStarted))
 	return nil
 }
 
diff --git a/service/round_test.go b/service/round_test.go
index 57f6f0fb..f83d0172 100644
--- a/service/round_test.go
+++ b/service/round_test.go
@@ -2,6 +2,7 @@ package service
 
 import (
 	"context"
+	"crypto/rand"
 	"fmt"
 	"path/filepath"
 	"testing"
@@ -12,6 +13,19 @@ import (
 	"github.com/spacemeshos/poet/prover"
 )
 
+func genChallenges(num int) ([][]byte, error) {
+	ch := make([][]byte, num)
+	for i := 0; i < num; i++ {
+		ch[i] = make([]byte, 32)
+		_, err := rand.Read(ch[i])
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return ch, nil
+}
+
 // TestRound_Recovery test round recovery functionality.
 // The scenario proceeds as follows:
 //   - Execute r1 as a reference round.
@@ -183,7 +197,7 @@ func TestRound_State(t *testing.T) {
 	req.Equal(prevState, state)
 
 	// Recover execution.
-	req.NoError(r.recoverExecution(ctx, state.Execution, time.Now().Add(100*time.Microsecond)))
+	req.NoError(r.recoverExecution(ctx, state.Execution, time.Now().Add(200*time.Millisecond)))
 	req.False(r.executionStarted.IsZero())
 
 	proof, err := r.proof(false)
diff --git a/service/service.go b/service/service.go
index d6d84d90..48c83679 100644
--- a/service/service.go
+++ b/service/service.go
@@ -54,12 +54,23 @@ type serviceState struct {
 	PrivKey []byte
}
 
+// ServiceClient provides a concurrency-safe handle for interacting with the Service actor.
+// It is created when the Service is started.
+type ServiceClient struct {
+	serviceStarted    *atomic.Bool
+	command           chan<- Command
+	challengeVerifier atomic.Value // holds challenge_verifier.Verifier
+}
+
 // Service orchestrates rounds functionality; each responsible for accepting challenges,
 // generating a proof from their hash digest, and broadcasting the result to the Spacemesh network.
 //
 // Service is single-use, meaning it can be started with `Start()` and then stopped with `Shutdown()`
 // but it cannot be restarted. A new instance of `Service` must be created.
 type Service struct {
+	commands <-chan Command
+	ServiceClient
+
 	runningGroup errgroup.Group
 	stop         context.CancelFunc
@@ -71,24 +82,22 @@ type Service struct {
 
 	// openRound is the round which is currently open for accepting challenges registration from miners.
 	// At any given time there is one single open round.
-	// openRoundMutex guards openRound, any access to it must be protected by this mutex.
-	openRound      *round
-	openRoundMutex sync.RWMutex
-
-	// executingRounds are the rounds which are currently executing, hence generating a proof.
-	executingRounds      map[string]*round
-	executingRoundsMutex sync.RWMutex
+	openRound       *round
+	executingRounds map[string]struct{}
 
 	PubKey  ed25519.PublicKey
 	privKey ed25519.PrivateKey
 
-	broadcaster       atomic.Value // holds Broadcaster interface
-	challengeVerifier atomic.Value // holds challenge_verifier.Verifier
+	broadcaster Broadcaster
 
-	errChan chan error
 	sync.Mutex
 }
 
+// Command is a function that will be run in the main Service loop.
+// Commands are run serially hence they don't require additional synchronization.
+// The functions must not block and should be kept short to avoid stalling the Service loop.
+type Command func(*Service)
+
 type InfoResponse struct {
 	OpenRoundID        string
 	ExecutingRoundsIds []string
@@ -164,65 +173,112 @@ func NewService(cfg *Config, datadir string) (*Service, error) {
 		state = initialState()
 	}
 
+	cmds := make(chan Command, 1)
+
 	privateKey := ed25519.NewKeyFromSeed(state.PrivKey[:32])
 	s := &Service{
+		commands: cmds,
+		ServiceClient: ServiceClient{
+			command: cmds,
+		},
 		cfg:             cfg,
 		minMemoryLayer:  uint(minMemoryLayer),
 		genesis:         genesis,
 		datadir:         datadir,
-		executingRounds: make(map[string]*round),
-		errChan:         make(chan error, 10),
+		executingRounds: make(map[string]struct{}),
 		privKey:         privateKey,
 		PubKey:          privateKey.Public().(ed25519.PublicKey),
 	}
 
+	s.ServiceClient.serviceStarted = &s.started
+
 	log.Info("Service public key: %x", s.PubKey)
 
 	return s, nil
 }
 
-func (s *Service) loop(ctx context.Context) {
-	var executingRounds errgroup.Group
-	defer executingRounds.Wait()
+type roundResult struct {
+	round *round
+	err   error
+}
+
+func (s *Service) loop(ctx context.Context, roundsToResume []*round) {
+	var eg errgroup.Group
+	defer eg.Wait()
+
+	logger := log.AppLog.WithName("worker")
+	ctx = logging.NewContext(ctx, logger)
+
+	roundResults := make(chan roundResult, 1)
+
+	// Resume recovered rounds
+	for _, round := range roundsToResume {
+		round := round
+		s.executingRounds[round.ID] = struct{}{}
+		end := s.roundEndTime(round)
+		eg.Go(func() error {
+			err := round.recoverExecution(ctx, round.stateCache.Execution, end)
+			roundResults <- roundResult{round: round, err: err}
+			return nil
+		})
+	}
+
+	timer := s.scheduleRound(ctx, s.openRound)
+
 	for {
-		s.openRoundMutex.RLock()
-		epoch := s.openRound.Epoch()
-		s.openRoundMutex.RUnlock()
-
-		start := s.genesis.Add(s.cfg.EpochDuration * time.Duration(epoch)).Add(s.cfg.PhaseShift)
-		waitTime := time.Until(start)
-		timer := 
time.After(waitTime) - if waitTime > 0 { - log.Info("Round %v waiting for execution to start for %v", s.openRoundID(), waitTime) - } select { + case cmd := <-s.commands: + cmd(s) + + case result := <-roundResults: + if result.err == nil { + broadcaster := s.broadcaster + go broadcastProof(s, result.round, result.round.execution, broadcaster) + } else { + logger.With().Warning("round execution failed", log.Err(result.err), log.String("round", result.round.ID)) + } + delete(s.executingRounds, result.round.ID) + case <-timer: + round := s.openRound + s.openRound = s.newRound(ctx, round.Epoch()+1) + s.executingRounds[round.ID] = struct{}{} + + end := s.roundEndTime(round) + minMemoryLayer := s.minMemoryLayer + eg.Go(func() error { + err := round.execute(ctx, end, minMemoryLayer) + roundResults <- roundResult{round, err} + return nil + }) + + // schedule the next round + timer = s.scheduleRound(ctx, s.openRound) + case <-ctx.Done(): - log.Info("service shutting down") - s.openRoundMutex.Lock() - s.openRound = nil - s.openRoundMutex.Unlock() + logger.Info("service shutting down") return } + } +} - s.openRoundMutex.Lock() - prevRound := s.openRound - s.newRound(ctx, prevRound.Epoch()+1) - s.openRoundMutex.Unlock() +func (s *Service) roundStartTime(round *round) time.Time { + return s.genesis.Add(s.cfg.PhaseShift).Add(s.cfg.EpochDuration * time.Duration(round.Epoch())) +} - executingRounds.Go(func() error { - round := prevRound - if err := s.executeRound(ctx, round); err != nil { - s.asyncError(fmt.Errorf("round %v execution error: %v", round.ID, err)) - return nil - } - broadcastProof(s, round, round.execution, s.getBroadcaster()) - return nil - }) +func (s *Service) roundEndTime(round *round) time.Time { + return s.roundStartTime(round).Add(s.cfg.EpochDuration).Add(-s.cfg.CycleGap) +} + +func (s *Service) scheduleRound(ctx context.Context, round *round) <-chan time.Time { + waitTime := time.Until(s.roundStartTime(round)) + timer := time.After(waitTime) + if waitTime > 0 { + logging.FromContext(ctx).With().Info("waiting for execution to start", log.Duration("wait time", waitTime), log.String("round", round.ID)) } + return timer } -func (s *Service) Start(b Broadcaster, atxProvider challenge_verifier.Verifier) error { +func (s *Service) Start(b Broadcaster, verifier challenge_verifier.Verifier) error { s.Lock() defer s.Unlock() if s.Started() { @@ -234,27 +290,32 @@ func (s *Service) Start(b Broadcaster, atxProvider challenge_verifier.Verifier) ctx, stop := context.WithCancel(context.Background()) s.stop = stop - s.SetBroadcaster(b) - s.SetChallengeVerifier(atxProvider) + s.broadcaster = b + var toResume []*round if s.cfg.NoRecovery { log.Info("Recovery is disabled") - } else if err := s.recover(ctx); err != nil { - return fmt.Errorf("failed to recover: %v", err) + } else { + var err error + s.openRound, toResume, err = s.recover(ctx) + if err != nil { + return fmt.Errorf("failed to recover: %v", err) + } } + now := time.Now() epoch := time.Duration(0) if d := now.Sub(s.genesis); d > 0 { epoch = d / s.cfg.EpochDuration } - s.openRoundMutex.Lock() if s.openRound == nil { - s.newRound(ctx, uint32(epoch)) + s.openRound = s.newRound(ctx, uint32(epoch)) } - s.openRoundMutex.Unlock() + + s.ServiceClient.SetChallengeVerifier(verifier) s.runningGroup.Go(func() error { - s.loop(ctx) + s.loop(ctx, toResume) return nil }) s.started.Store(true) @@ -279,125 +340,75 @@ func (s *Service) Started() bool { return s.started.Load() } -func (s *Service) recover(ctx context.Context) error { - 
log.With().Info("Recovering service state", log.String("datadir", s.datadir)) +func (s *Service) recover(ctx context.Context) (open *round, executing []*round, err error) { + logger := log.AppLog.WithName("recovery") + logger.With().Info("Recovering service state", log.String("datadir", s.datadir)) entries, err := os.ReadDir(s.datadir) if err != nil { - return err + return nil, nil, err } for _, entry := range entries { - log.Info("Recovering entry %s", entry.Name()) + logger.Info("recovering entry %s", entry.Name()) if !entry.IsDir() { continue } epoch, err := strconv.ParseUint(entry.Name(), 10, 32) if err != nil { - return fmt.Errorf("entry is not a uint32 %s", entry.Name()) + return nil, nil, fmt.Errorf("entry is not a uint32 %s", entry.Name()) } r := newRound(ctx, s.datadir, uint32(epoch)) state, err := r.state() if err != nil { - return fmt.Errorf("invalid round state: %v", err) + return nil, nil, fmt.Errorf("invalid round state: %v", err) } if state.isExecuted() { - log.Info("Recovery: found round %v in executed state. broadcasting...", r.ID) - go broadcastProof(s, r, state.Execution, s.getBroadcaster()) + logger.Info("found round %v in executed state. broadcasting...", r.ID) + go broadcastProof(s, r, state.Execution, s.broadcaster) continue } if state.isOpen() { - log.Info("Recovery: found round %v in open state.", r.ID) + logger.Info("found round %v in open state.", r.ID) if err := r.open(); err != nil { - return fmt.Errorf("failed to open round: %v", err) + return nil, nil, fmt.Errorf("failed to open round: %v", err) } // Keep the last open round as openRound (multiple open rounds state is possible // only if recovery was previously disabled). - s.openRound = r + open = r continue } - log.Info("Recovery: found round %v in executing state. recovering execution...", r.ID) - s.executingRoundsMutex.Lock() - s.executingRounds[r.ID] = r - s.executingRoundsMutex.Unlock() - s.runningGroup.Go(func() error { - r, rs := r, state - defer func() { - s.executingRoundsMutex.Lock() - delete(s.executingRounds, r.ID) - s.executingRoundsMutex.Unlock() - }() - - end := s.genesis. - Add(s.cfg.EpochDuration * time.Duration(r.Epoch()+1)). - Add(s.cfg.PhaseShift). - Add(-s.cfg.CycleGap) - - if err = r.recoverExecution(ctx, rs.Execution, end); err != nil { - s.asyncError(fmt.Errorf("recovery: round %v execution failure: %v", r.ID, err)) - return nil - } - - log.Info("Recovery: round %v execution ended, phi=%x", r.ID, r.execution.NIP.Root) - broadcastProof(s, r, r.execution, s.getBroadcaster()) - return nil - }) + logger.Info("found round %v in executing state.", r.ID) + executing = append(executing, r) } - return nil + return open, executing, nil } -func (s *Service) getBroadcaster() Broadcaster { - return s.broadcaster.Load().(Broadcaster) -} - -func (s *Service) SetBroadcaster(b Broadcaster) { - if s.broadcaster.Swap(b) != nil { - log.Info("Service broadcaster updated") +func (s *ServiceClient) SetBroadcaster(b Broadcaster) { + // No need to wait for the Command to execute. 
+	s.command <- func(s *Service) {
+		old := s.broadcaster
+		s.broadcaster = b
+		if old != nil {
+			log.Info("Service broadcaster updated")
+		}
 	}
 }
 
-func (s *Service) SetChallengeVerifier(provider challenge_verifier.Verifier) {
+func (s *ServiceClient) SetChallengeVerifier(provider challenge_verifier.Verifier) {
 	s.challengeVerifier.Store(provider)
 }
 
-func (s *Service) executeRound(ctx context.Context, r *round) error {
-	s.executingRoundsMutex.Lock()
-	s.executingRounds[r.ID] = r
-	s.executingRoundsMutex.Unlock()
-
-	defer func() {
-		s.executingRoundsMutex.Lock()
-		delete(s.executingRounds, r.ID)
-		s.executingRoundsMutex.Unlock()
-	}()
-
-	start := time.Now()
-	end := s.genesis.
-		Add(s.cfg.EpochDuration * time.Duration(r.Epoch()+1)).
-		Add(s.cfg.PhaseShift).
-		Add(-s.cfg.CycleGap)
-
-	log.Info("Round %v executing until %v...", r.ID, end)
-
-	if err := r.execute(ctx, end, uint(s.minMemoryLayer)); err != nil {
-		return err
+func (s *ServiceClient) Submit(ctx context.Context, challenge, signature []byte) (string, []byte, error) {
+	if !s.serviceStarted.Load() {
+		return "", nil, ErrNotStarted
 	}
-
-	log.Info("Round %v execution ended, phi=%x, duration %v", r.ID, r.execution.NIP.Root, time.Since(start))
-
-	return nil
-}
-
-func (s *Service) Submit(ctx context.Context, challenge, signature []byte) (*round, []byte, error) {
 	logger := logging.FromContext(ctx)
-	if !s.Started() {
-		return nil, nil, ErrNotStarted
-	}
 	logger.Debug("Received challenge")
 	// SAFETY: it will never panic as `s.ChallengeVerifier` is set in Start
@@ -405,42 +416,65 @@ func (s *Service) Submit(ctx context.Context, challenge, signature []byte) (*rou
 	result, err := verifier.Verify(ctx, challenge, signature)
 	if err != nil {
 		logger.With().Debug("challenge verification failed", log.Err(err))
-		return nil, nil, err
+		return "", nil, err
 	}
 	logger.With().Debug("verified challenge",
 		log.String("hash", hex.EncodeToString(result.Hash)),
 		log.String("node_id", hex.EncodeToString(result.NodeId)))
-	s.openRoundMutex.Lock()
-	r := s.openRound
-	err = r.submit(result.NodeId, result.Hash)
-	s.openRoundMutex.Unlock()
-	switch {
-	case errors.Is(err, ErrChallengeAlreadySubmitted):
-		return r, result.Hash, nil
-	case err != nil:
-		return nil, nil, err
+
+	type response struct {
+		round string
+		err   error
+	}
+	done := make(chan response, 1)
+	s.command <- func(s *Service) {
+		done <- response{
+			round: s.openRound.ID,
+			err:   s.openRound.submit(result.NodeId, result.Hash),
+		}
+		close(done)
+	}
+
+	select {
+	case resp := <-done:
+		switch {
+		case errors.Is(resp.err, ErrChallengeAlreadySubmitted):
+			return resp.round, result.Hash, nil
+		case resp.err != nil:
+			return "", nil, resp.err
+		}
+		logger.With().Debug("submitted challenge for round", log.String("round", resp.round))
+		return resp.round, result.Hash, nil
+	case <-ctx.Done():
+		return "", nil, ctx.Err()
 	}
-	return r, result.Hash, nil
 }
 
-func (s *Service) Info() (*InfoResponse, error) {
-	if !s.Started() {
+func (s *ServiceClient) Info(ctx context.Context) (*InfoResponse, error) {
+	if !s.serviceStarted.Load() {
 		return nil, ErrNotStarted
 	}
 
-	s.executingRoundsMutex.RLock()
-	ids := make([]string, 0, len(s.executingRounds))
-	for id := range s.executingRounds {
-		ids = append(ids, id)
+	resp := make(chan *InfoResponse, 1)
+	s.command <- func(s *Service) {
+		defer close(resp)
+		ids := make([]string, 0, len(s.executingRounds))
+		for id := range s.executingRounds {
+			ids = append(ids, id)
+		}
+		resp <- &InfoResponse{
+			OpenRoundID:        s.openRound.ID,
+			ExecutingRoundsIds: ids,
+		}
+	}
+	select {
+	case info := <-resp:
		return info, nil
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
-	s.executingRoundsMutex.RUnlock()
-
-	return &InfoResponse{
-		OpenRoundID:        s.openRoundID(),
-		ExecutingRoundsIds: ids,
-	}, nil
 }
 
-// newRound creates a new round with the given epoch. This method MUST be guarded by a write lock on openRoundMutex.
-func (s *Service) newRound(ctx context.Context, epoch uint32) {
+// newRound creates a new round with the given epoch.
+func (s *Service) newRound(ctx context.Context, epoch uint32) *round {
 	if err := saveState(s.datadir, s.privKey); err != nil {
 		panic(err)
 	}
@@ -449,19 +483,8 @@ func (s *Service) newRound(ctx context.Context, epoch uint32) {
 		panic(fmt.Errorf("failed to open round: %v", err))
 	}
 
-	s.openRound = r
-	log.With().Info("Round opened", log.String("ID", s.openRound.ID))
-}
-
-func (s *Service) openRoundID() string {
-	s.openRoundMutex.RLock()
-	defer s.openRoundMutex.RUnlock()
-	return s.openRound.ID
-}
-
-func (s *Service) asyncError(err error) {
-	log.Error(err.Error())
-	s.errChan <- err
+	log.With().Info("Round opened", log.String("ID", r.ID))
+	return r
 }
 
 func broadcastProof(s *Service, r *round, execution *executionState, broadcaster Broadcaster) {
diff --git a/service/service_test.go b/service/service_test.go
index 3a06d2fe..28ebcf57 100644
--- a/service/service_test.go
+++ b/service/service_test.go
@@ -1,4 +1,4 @@
-package service
+package service_test
 
 import (
 	"bytes"
@@ -11,13 +11,13 @@ import (
 
 	"github.com/golang/mock/gomock"
 	"github.com/spacemeshos/go-scale"
-	"github.com/spacemeshos/merkle-tree"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/exp/slices"
 	"golang.org/x/sync/errgroup"
 
 	"github.com/spacemeshos/poet/gateway/challenge_verifier"
 	"github.com/spacemeshos/poet/gateway/challenge_verifier/mocks"
-	"github.com/spacemeshos/poet/prover"
+	"github.com/spacemeshos/poet/service"
 )
 
 type MockBroadcaster struct {
@@ -32,13 +32,12 @@ func (b *MockBroadcaster) BroadcastProof(msg []byte, roundID string, members [][
 type challenge struct {
 	data   []byte
 	nodeID []byte
-	round  *round
 }
 
 func TestService_Recovery(t *testing.T) {
 	req := require.New(t)
 	broadcaster := &MockBroadcaster{receivedMessages: make(chan []byte)}
-	cfg := &Config{
+	cfg := &service.Config{
 		Genesis:       time.Now().Add(time.Second).Format(time.RFC3339),
 		EpochDuration: time.Second,
 		PhaseShift:    time.Second / 2,
@@ -49,20 +48,17 @@ func TestService_Recovery(t *testing.T) {
 	verifier := mocks.NewMockVerifier(ctrl)
 	tempdir := t.TempDir()
+
 	// Create a new service instance.
-	s, err := NewService(cfg, tempdir)
+	s, err := service.NewService(cfg, tempdir)
 	req.NoError(err)
 	err = s.Start(broadcaster, verifier)
 	req.NoError(err)
 
-	// Track the service rounds.
-	numRounds := 3
-	rounds := make([]*round, numRounds)
-
-	// Generate 4 groups of random challenges.
+	// Generate 3 groups of random challenges.
- challengeGroupSize := 40 - challengeGroups := make([][]challenge, 4) - for i := 0; i < 4; i++ { + challengeGroupSize := 10 + challengeGroups := make([][]challenge, 3) + for i := 0; i < 3; i++ { challengeGroup := make([]challenge, challengeGroupSize) for i := 0; i < challengeGroupSize; i++ { challengeGroup[i] = challenge{data: make([]byte, 32), nodeID: make([]byte, 32)} @@ -74,130 +70,68 @@ func TestService_Recovery(t *testing.T) { challengeGroups[i] = challengeGroup } - submittedChallenges := make(map[int][]challenge) - submitChallenges := func(roundIndex int, groupIndex int) { - challengesGroup := challengeGroups[groupIndex] - for _, challenge := range challengeGroups[groupIndex] { + submitChallenges := func(roundID string, challenges []challenge) { + for _, challenge := range challenges { verifier.EXPECT().Verify(gomock.Any(), challenge.data, nil).Return(&challenge_verifier.Result{Hash: challenge.data, NodeId: challenge.nodeID}, nil) round, hash, err := s.Submit(context.Background(), challenge.data, nil) req.NoError(err) req.Equal(challenge.data, hash) - req.Equal(strconv.Itoa(roundIndex), round.ID) - - // Verify that all submissions returned the same round instance. - if rounds[roundIndex] == nil { - rounds[roundIndex] = round - } else { - req.Equal(rounds[roundIndex], round) - } + req.Equal(roundID, round) } - - // Track the submitted challenges per-round for later validation. - submittedChallenges[roundIndex] = append(submittedChallenges[roundIndex], challengesGroup...) } // Submit challenges to open round (0). - submitChallenges(0, 0) - - // Verify that round is still open. - req.Equal(rounds[0].ID, s.openRoundID()) + submitChallenges("0", challengeGroups[0]) // Wait for round 0 to start executing. - select { - case <-rounds[0].executionStartedChan: - case err := <-s.errChan: - req.Fail(err.Error()) - } - - // Verify that round iteration proceeds: a new round opened, previous round is executing. - s.Lock() - req.Contains(s.executingRounds, rounds[0].ID) - s.Unlock() - s.openRoundMutex.Lock() - rounds[1] = s.openRound - s.openRoundMutex.Unlock() + req.Eventually(func() bool { + info, err := s.Info(context.Background()) + req.NoError(err) + return slices.Contains(info.ExecutingRoundsIds, "0") + }, cfg.EpochDuration*2, time.Millisecond*20) // Submit challenges to open round (1). - submitChallenges(1, 1) + submitChallenges("1", challengeGroups[1]) req.NoError(s.Shutdown()) - // Verify shutdown error is received. - select { - case err := <-s.errChan: - req.EqualError(err, fmt.Sprintf("round %v execution error: %v", rounds[0].ID, prover.ErrShutdownRequested.Error())) - case <-rounds[0].executionEndedChan: - req.Fail("round execution ended instead of shutting down") - } - - // Verify service state. should have no open or executing rounds. Check after a delay to allow service to update state. - time.Sleep(100 * time.Millisecond) - s.openRoundMutex.Lock() - req.Nil(s.openRound) - s.openRoundMutex.Unlock() - s.Lock() - req.Equal(len(s.executingRounds), 0) - s.Unlock() // Create a new service instance. - s, err = NewService(cfg, tempdir) + s, err = service.NewService(cfg, tempdir) req.NoError(err) err = s.Start(broadcaster, verifier) req.NoError(err) - // Service instance should recover 2 rounds: round 1 in executing state, and round 2 in open state. 
- req.Eventually(func() bool { s.Lock(); defer s.Unlock(); return len(s.executingRounds) == 1 }, time.Second, time.Millisecond) - s.Lock() - first, ok := s.executingRounds["0"] - s.Unlock() - req.True(ok) - req.Equal(s.openRoundID(), "1") - rounds[0] = first - s.openRoundMutex.Lock() - rounds[1] = s.openRound - s.openRoundMutex.Unlock() - - select { - case <-rounds[1].executionStartedChan: - case err := <-s.errChan: - req.Fail(err.Error()) - } - - submitChallenges(2, 2) + // Service instance should recover 2 rounds: round 0 in executing state, and round 1 in open state. + info, err := s.Info(context.Background()) + req.NoError(err) + req.Equal("1", info.OpenRoundID) + req.Len(info.ExecutingRoundsIds, 1) + req.Contains(info.ExecutingRoundsIds, "0") + req.Equal([]string{"0"}, info.ExecutingRoundsIds) + + // Wait for round 2 to open + req.Eventually(func() bool { + info, err := s.Info(context.Background()) + req.NoError(err) + return info.OpenRoundID == "2" + }, cfg.EpochDuration*2, time.Millisecond*20) - select { - case <-rounds[2].executionEndedChan: - case err := <-s.errChan: - req.Fail(err.Error()) - } + submitChallenges("2", challengeGroups[2]) - // Verify that new service instance broadcast 3 distinct rounds proofs, by the expected order. - for i := 0; i < numRounds; i++ { - proofMsg := PoetProofMessage{} - select { - case <-time.After(10 * time.Second): - req.Fail("proof message wasn't sent") - case msg := <-broadcaster.receivedMessages: - dec := scale.NewDecoder(bytes.NewReader(msg)) - _, err := proofMsg.DecodeScale(dec) - req.NoError(err) - } + for i := 0; i < len(challengeGroups); i++ { + msg := <-broadcaster.receivedMessages + proofMsg := service.PoetProofMessage{} + dec := scale.NewDecoder(bytes.NewReader(msg)) + _, err := proofMsg.DecodeScale(dec) + req.NoError(err) + req.Equal(strconv.Itoa(i), proofMsg.RoundID) // Verify the submitted challenges. - req.Len(proofMsg.Members, len(submittedChallenges[i])) - for _, ch := range submittedChallenges[i] { - req.Contains(proofMsg.Members, ch.data, "proof %v, round %v", proofMsg.RoundID, i) - } - - // Verify round statement. 
- mtree, err := merkle.NewTree() - req.NoError(err) - for _, m := range proofMsg.Members { - req.NoError(mtree.AddLeaf(m)) + req.Len(proofMsg.Members, len(challengeGroups[i]), "round: %v i: %d", proofMsg.RoundID, i) + for _, ch := range challengeGroups[i] { + req.Contains(proofMsg.Members, ch.data, "round: %v, i: %d", proofMsg.RoundID, i) } - proof, err := rounds[i].proof(false) - req.NoError(err, "round %d", i) - req.Equal(mtree.Root(), proof.Statement) } req.NoError(s.Shutdown()) @@ -207,7 +141,7 @@ func TestConcurrentServiceStartAndShutdown(t *testing.T) { t.Parallel() req := require.New(t) - cfg := Config{ + cfg := service.Config{ Genesis: time.Now().Add(2 * time.Second).Format(time.RFC3339), EpochDuration: time.Second, PhaseShift: time.Second / 2, @@ -219,7 +153,7 @@ func TestConcurrentServiceStartAndShutdown(t *testing.T) { for i := 0; i < 100; i += 1 { t.Run(fmt.Sprintf("iteration %d", i), func(t *testing.T) { t.Parallel() - s, err := NewService(&cfg, t.TempDir()) + s, err := service.NewService(&cfg, t.TempDir()) req.NoError(err) var eg errgroup.Group @@ -238,13 +172,13 @@ func TestNewService(t *testing.T) { req := require.New(t) tempdir := t.TempDir() - cfg := new(Config) + cfg := new(service.Config) cfg.Genesis = time.Now().Add(time.Second).Format(time.RFC3339) cfg.EpochDuration = time.Second cfg.PhaseShift = time.Second / 2 cfg.CycleGap = time.Second / 4 - s, err := NewService(cfg, tempdir) + s, err := service.NewService(cfg, tempdir) req.NoError(err) proofBroadcaster := &MockBroadcaster{receivedMessages: make(chan []byte)} ctrl := gomock.NewController(t) @@ -264,7 +198,7 @@ func TestNewService(t *testing.T) { req.NoError(err) } - info, err := s.Info() + info, err := s.Info(context.Background()) require.NoError(t, err) currentRound := info.OpenRoundID @@ -274,52 +208,49 @@ func TestNewService(t *testing.T) { round, hash, err := s.Submit(context.Background(), challenges[i].data, nil) req.NoError(err) req.Equal(challenges[i].data, hash) - req.Equal(currentRound, round.ID) - challenges[i].round = round - - // Verify that all submissions returned the same round instance. - if i > 0 { - req.Equal(challenges[i].round, challenges[i-1].round) - } + req.Equal(currentRound, round) } // Verify that round is still open. - info, err = s.Info() + info, err = s.Info(context.Background()) req.NoError(err) req.Equal(currentRound, info.OpenRoundID) // Wait for round to start execution. - select { - case <-challenges[0].round.executionStartedChan: - case err := <-s.errChan: - req.Fail(err.Error()) - } - - // Verify that round iteration proceeded. - prevInfo := info - info, err = s.Info() - req.NoError(err) - prevIndex, err := strconv.Atoi(prevInfo.OpenRoundID) - req.NoError(err) - req.Equal(fmt.Sprintf("%d", prevIndex+1), info.OpenRoundID) - req.Contains(info.ExecutingRoundsIds, prevInfo.OpenRoundID) + req.Eventually(func() bool { + info, err := s.Info(context.Background()) + req.NoError(err) + for _, r := range info.ExecutingRoundsIds { + if r == currentRound { + return true + } + } + return false + }, cfg.EpochDuration*2, time.Millisecond*20) // Wait for end of execution. 
- select { - case <-challenges[0].round.executionEndedChan: - case err := <-s.errChan: - req.Fail(err.Error()) - } + req.Eventually(func() bool { + info, err := s.Info(context.Background()) + req.NoError(err) + prevRoundID, err := strconv.Atoi(currentRound) + req.NoError(err) + currRoundID, err := strconv.Atoi(info.OpenRoundID) + req.NoError(err) + return currRoundID >= prevRoundID+1 + }, time.Second, time.Millisecond*20) // Wait for proof message broadcast. - select { - case msg := <-proofBroadcaster.receivedMessages: - poetProof := PoetProofMessage{} - dec := scale.NewDecoder(bytes.NewReader(msg)) - _, err := poetProof.DecodeScale(dec) - req.NoError(err) - case <-time.After(100 * time.Millisecond): - req.Fail("proof message wasn't sent") + msg := <-proofBroadcaster.receivedMessages + proof := service.PoetProofMessage{} + dec := scale.NewDecoder(bytes.NewReader(msg)) + _, err = proof.DecodeScale(dec) + req.NoError(err) + + req.Equal(currentRound, proof.RoundID) + // Verify the submitted challenges. + req.Len(proof.Members, len(challenges)) + for _, ch := range challenges { + req.Contains(proof.Members, ch.data) } req.NoError(s.Shutdown()) @@ -327,7 +258,7 @@ func TestNewService(t *testing.T) { func TestSubmitIdempotency(t *testing.T) { req := require.New(t) - cfg := Config{ + cfg := service.Config{ Genesis: time.Now().Add(time.Second).Format(time.RFC3339), EpochDuration: time.Second, PhaseShift: time.Second / 2, @@ -336,7 +267,7 @@ func TestSubmitIdempotency(t *testing.T) { challenge := []byte("challenge") signature := []byte("signature") - s, err := NewService(&cfg, t.TempDir()) + s, err := service.NewService(&cfg, t.TempDir()) req.NoError(err) proofBroadcaster := &MockBroadcaster{receivedMessages: make(chan []byte)} @@ -357,16 +288,3 @@ func TestSubmitIdempotency(t *testing.T) { req.NoError(s.Shutdown()) } - -func genChallenges(num int) ([][]byte, error) { - ch := make([][]byte, num) - for i := 0; i < num; i++ { - ch[i] = make([]byte, 32) - _, err := rand.Read(ch[i]) - if err != nil { - return nil, err - } - } - - return ch, nil -}
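The heart of this patch is the switch from mutex-guarded shared state to an actor: loop() owns openRound and executingRounds outright, and every external mutation or query travels through the Command channel and runs serially. Below is a standalone sketch of the round-trip shape that ServiceClient.Info and Submit now use, including the buffered reply channel and the ctx.Done() escape hatch. It is not part of the patch; all identifiers are illustrative.

package main

import (
	"context"
	"fmt"
)

// state stands in for the fields owned by the loop goroutine.
type state struct{ openRound string }

// command mirrors the patch's Command type: a closure run inside the loop.
type command func(*state)

func loop(ctx context.Context, cmds <-chan command, s *state) {
	for {
		select {
		case cmd := <-cmds:
			cmd(s) // commands run one at a time; no mutex needed
		case <-ctx.Done():
			return
		}
	}
}

// openRoundID round-trips a query through the loop, honoring cancellation.
func openRoundID(ctx context.Context, cmds chan<- command) (string, error) {
	resp := make(chan string, 1) // buffered: the loop never blocks on the reply
	cmds <- func(s *state) { resp <- s.openRound }
	select {
	case id := <-resp:
		return id, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cmds := make(chan command, 1)
	go loop(ctx, cmds, &state{openRound: "0"})

	id, err := openRoundID(ctx, cmds)
	fmt.Println(id, err)
}

Buffering the reply channel with capacity 1 matters: if the caller gives up via ctx, the command's send still succeeds, so the loop never blocks on an abandoned request. This is why running the tests under -race (the Makefile change) is now expected to stay clean without the removed submitMtx and round mutexes.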
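The new roundStartTime and roundEndTime helpers centralize the epoch arithmetic that was previously duplicated in loop(), executeRound() and recover(): round N starts executing at genesis + PhaseShift + N*EpochDuration and must finish a CycleGap before round N+1 starts. A sketch of that arithmetic using the durations from the test config (EpochDuration 1s, PhaseShift 500ms, CycleGap 250ms); the helper names are local to this example.

package main

import (
	"fmt"
	"time"
)

const (
	epochDuration = time.Second     // cfg.EpochDuration in the tests
	phaseShift    = time.Second / 2 // cfg.PhaseShift
	cycleGap      = time.Second / 4 // cfg.CycleGap
)

// roundStart is when execution of the round opened at `epoch` begins.
func roundStart(genesis time.Time, epoch uint32) time.Time {
	return genesis.Add(phaseShift).Add(epochDuration * time.Duration(epoch))
}

// roundEnd is the execution deadline: one cycle gap before the next round.
func roundEnd(genesis time.Time, epoch uint32) time.Time {
	return roundStart(genesis, epoch).Add(epochDuration).Add(-cycleGap)
}

func main() {
	genesis := time.Now()
	for epoch := uint32(0); epoch < 3; epoch++ {
		s, e := roundStart(genesis, epoch), roundEnd(genesis, epoch)
		// e.g. round 0 proves from genesis+500ms until genesis+1.25s
		fmt.Printf("round %d: +%v .. +%v (window %v)\n",
			epoch, s.Sub(genesis), e.Sub(genesis), e.Sub(s))
	}
}

With these numbers every round gets a 750ms proving window, which is why the recovery test can rely on cfg.EpochDuration*2 as a polling deadline.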
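Note that SetChallengeVerifier still uses an atomic.Value rather than a Command: Submit must read the verifier before it touches any actor state, and UpdateGateway can swap the verifier at runtime. A minimal sketch of that hand-off, under the assumption that every stored value has the same concrete type (a requirement of atomic.Value); the types here are invented for the example.

package main

import (
	"fmt"
	"sync/atomic"
)

type Verifier interface {
	Verify(challenge []byte) error
}

type nopVerifier struct{}

func (nopVerifier) Verify([]byte) error { return nil }

type client struct {
	verifier atomic.Value // always holds a Verifier; stored before first use
}

func (c *client) setVerifier(v Verifier) { c.verifier.Store(v) }

func (c *client) submit(challenge []byte) error {
	// SAFETY: mirrors the patch's comment; Store runs in Start, before
	// submit is reachable, so Load never returns nil here.
	return c.verifier.Load().(Verifier).Verify(challenge)
}

func main() {
	c := &client{}
	c.setVerifier(nopVerifier{})
	fmt.Println(c.submit([]byte("challenge")))
}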
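Finally, moving the tests to package service_test means they can no longer poke at private fields such as executingRounds; instead they poll the public Info API with require.Eventually and slices.Contains from golang.org/x/exp (hence the new go.mod entry). A self-contained sketch of that polling idiom against a stand-in data source; the infoSource type is invented here.

package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"golang.org/x/exp/slices"
)

// infoSource stands in for ServiceClient.Info's view of executing rounds.
type infoSource struct{ start time.Time }

func (s *infoSource) executingRounds() []string {
	if time.Since(s.start) > 50*time.Millisecond {
		return []string{"0"} // round 0 eventually starts executing
	}
	return nil
}

func TestPollUntilRoundExecutes(t *testing.T) {
	src := &infoSource{start: time.Now()}
	require.Eventually(t, func() bool {
		return slices.Contains(src.executingRounds(), "0")
	}, time.Second, 20*time.Millisecond)
}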