diff --git a/signer/cosigner_nonce_cache.go b/signer/cosigner_nonce_cache.go
index eac24d92..1b153db6 100644
--- a/signer/cosigner_nonce_cache.go
+++ b/signer/cosigner_nonce_cache.go
@@ -13,8 +13,8 @@ import (
 const (
 	defaultGetNoncesInterval = 3 * time.Second
 	defaultGetNoncesTimeout  = 4 * time.Second
-	cachePreSize             = 10000
-	nonceCacheExpiration     = 10 * time.Second // half of the local cosigner cache expiration
+	defaultNonceExpiration   = 10 * time.Second // half of the local cosigner cache expiration
+	targetTrim               = 10
 )
 
 type CosignerNonceCache struct {
@@ -29,10 +29,13 @@ type CosignerNonceCache struct {
 
 	getNoncesInterval time.Duration
 	getNoncesTimeout  time.Duration
+	nonceExpiration   time.Duration
 
 	threshold uint8
 
 	cache NonceCache
+
+	pruner NonceCachePruner
 }
 
 type lastCount struct {
@@ -58,15 +61,13 @@ func (lc *lastCount) Get() int {
 	return lc.count
 }
 
-type NonceCache struct {
-	cache map[uuid.UUID]*CachedNonce
-	mu    sync.RWMutex
+type NonceCachePruner interface {
+	PruneNonces() int
 }
 
-func NewNonceCache() NonceCache {
-	return NonceCache{
-		cache: make(map[uuid.UUID]*CachedNonce, cachePreSize),
-	}
+type NonceCache struct {
+	cache []*CachedNonce
+	mu    sync.RWMutex
 }
 
 func (nc *NonceCache) Size() int {
@@ -75,17 +76,14 @@ func (nc *NonceCache) Size() int {
 	return len(nc.cache)
 }
 
-func (nc *NonceCache) Get(uuid uuid.UUID) (*CachedNonce, bool) {
-	nc.mu.RLock()
-	defer nc.mu.RUnlock()
-	cn, ok := nc.cache[uuid]
-	return cn, ok
-}
-
-func (nc *NonceCache) Set(uuid uuid.UUID, cn *CachedNonce) {
+func (nc *NonceCache) Add(cn *CachedNonce) {
 	nc.mu.Lock()
 	defer nc.mu.Unlock()
-	nc.cache[uuid] = cn
+	nc.cache = append(nc.cache, cn)
+}
+
+func (nc *NonceCache) Delete(index int) {
+	nc.cache = append(nc.cache[:index], nc.cache[index+1:]...)
 }
 
 type CosignerNoncesRel struct {
@@ -115,17 +113,26 @@ func NewCosignerNonceCache(
 	leader Leader,
 	getNoncesInterval time.Duration,
 	getNoncesTimeout time.Duration,
+	nonceExpiration time.Duration,
 	threshold uint8,
+	pruner NonceCachePruner,
 ) *CosignerNonceCache {
-	return &CosignerNonceCache{
+	cnc := &CosignerNonceCache{
 		logger:            logger,
-		cache:             NewNonceCache(),
 		cosigners:         cosigners,
 		leader:            leader,
 		getNoncesInterval: getNoncesInterval,
 		getNoncesTimeout:  getNoncesTimeout,
+		nonceExpiration:   nonceExpiration,
 		threshold:         threshold,
+		pruner:            pruner,
+	}
+	// the only time pruner is expected to be non-nil is during tests, otherwise we use the cache logic.
+	if pruner == nil {
+		cnc.pruner = cnc
 	}
+
+	return cnc
 }
 
 func (cnc *CosignerNonceCache) getUuids(n int) []uuid.UUID {
@@ -136,9 +143,13 @@ func (cnc *CosignerNonceCache) getUuids(n int) []uuid.UUID {
 	return uuids
 }
 
+func (cnc *CosignerNonceCache) target() int {
+	return int((cnc.noncesPerMinute/60)*cnc.getNoncesInterval.Seconds()*1.2) + int(cnc.noncesPerMinute/30) + targetTrim
+}
+
 func (cnc *CosignerNonceCache) reconcile(ctx context.Context) {
 	// prune expired nonces
-	cnc.pruneNonces()
+	pruned := cnc.pruner.PruneNonces()
 
 	if !cnc.leader.IsLeader() {
 		return
@@ -146,8 +157,9 @@ func (cnc *CosignerNonceCache) reconcile(ctx context.Context) {
 	remainingNonces := cnc.cache.Size()
 	timeSinceLastReconcile := time.Since(cnc.lastReconcileTime)
 
+	lastReconcileNonces := cnc.lastReconcileNonces.Get()
 	// calculate nonces per minute
-	noncesPerMin := float64(cnc.lastReconcileNonces.Get()-remainingNonces) / timeSinceLastReconcile.Minutes()
+	noncesPerMin := float64(lastReconcileNonces-remainingNonces-pruned) / timeSinceLastReconcile.Minutes()
 	if noncesPerMin < 0 {
 		noncesPerMin = 0
 	}
@@ -167,15 +179,16 @@
 
 	// calculate how many nonces we need to load to keep up with demand
 	// load 120% the number of nonces we need to keep up with demand,
+	// plus a couple seconds worth of nonces to account for nonce consumption during LoadN
 	// plus 10 for padding
-	target := int((cnc.noncesPerMinute/60)*cnc.getNoncesInterval.Seconds()*1.2) + 10
-	additional := target - remainingNonces
+	t := cnc.target()
+	additional := t - remainingNonces
 
 	if additional <= 0 {
 		// we're ahead of demand, don't load any more
 		cnc.logger.Debug(
 			"Cosigner nonce cache ahead of demand",
-			"target", target,
+			"target", t,
 			"remaining", remainingNonces,
 			"noncesPerMin", cnc.noncesPerMinute,
 		)
@@ -184,7 +197,7 @@
 
 	cnc.logger.Debug(
 		"Loading additional nonces to meet demand",
-		"target", target,
+		"target", t,
 		"remaining", remainingNonces,
 		"additional", additional,
 		"noncesPerMin", cnc.noncesPerMinute,
@@ -202,7 +215,7 @@ func (cnc *CosignerNonceCache) LoadN(ctx context.Context, n int) {
 	var wg sync.WaitGroup
 	wg.Add(len(cnc.cosigners))
 
-	expiration := time.Now().Add(nonceCacheExpiration)
+	expiration := time.Now().Add(cnc.nonceExpiration)
 
 	for i, p := range cnc.cosigners {
 		i := i
@@ -251,7 +264,7 @@
 			})
 		}
 		if num >= cnc.threshold {
-			cnc.cache.Set(u, &nonce)
+			cnc.cache.Add(&nonce)
 			added++
 		}
 	}
@@ -274,10 +287,10 @@ func (cnc *CosignerNonceCache) Start(ctx context.Context) {
 }
 
 func (cnc *CosignerNonceCache) GetNonces(fastestPeers []Cosigner) (*CosignerUUIDNonces, error) {
-	cnc.cache.mu.RLock()
-	defer cnc.cache.mu.RUnlock()
+	cnc.cache.mu.Lock()
+	defer cnc.cache.mu.Unlock()
 CheckNoncesLoop:
-	for u, cn := range cnc.cache.cache {
+	for i, cn := range cnc.cache.cache {
 		var nonces CosignerNonces
 		for _, p := range fastestPeers {
 			found := false
@@ -294,13 +307,12 @@
 			}
 		}
 
-		cnc.cache.mu.RUnlock()
-		cnc.clearNonce(u)
-		cnc.cache.mu.RLock()
+		// remove this set of nonces from the cache
+		cnc.cache.Delete(i)
 
 		// all peers found
 		return &CosignerUUIDNonces{
-			UUID:   u,
+			UUID:   cn.UUID,
 			Nonces: nonces,
 		}, nil
 	}
@@ -316,26 +328,32 @@
 	return nil, fmt.Errorf("no nonces found involving cosigners %+v", cosignerInts)
 }
 
-func (cnc *CosignerNonceCache) pruneNonces() {
+func (cnc *CosignerNonceCache) PruneNonces() int {
 	cnc.cache.mu.Lock()
 	defer cnc.cache.mu.Unlock()
-	for u, cn := range cnc.cache.cache {
-		if time.Now().After(cn.Expiration) {
-			delete(cnc.cache.cache, u)
+	nonExpiredIndex := len(cnc.cache.cache) - 1
+	for i := len(cnc.cache.cache) - 1; i >= 0; i-- {
+		if time.Now().Before(cnc.cache.cache[i].Expiration) {
+			nonExpiredIndex = i
+			break
+		}
+		if i == 0 {
+			deleteCount := len(cnc.cache.cache)
+			cnc.cache.cache = nil
+			return deleteCount
 		}
 	}
-}
-
-func (cnc *CosignerNonceCache) clearNonce(uuid uuid.UUID) {
-	cnc.cache.mu.Lock()
-	defer cnc.cache.mu.Unlock()
-	delete(cnc.cache.cache, uuid)
+	deleteCount := len(cnc.cache.cache) - nonExpiredIndex - 1
+	if nonExpiredIndex != len(cnc.cache.cache)-1 {
+		cnc.cache.cache = cnc.cache.cache[:nonExpiredIndex+1]
+	}
+	return deleteCount
 }
 
 func (cnc *CosignerNonceCache) ClearNonces(cosigner Cosigner) {
 	cnc.cache.mu.Lock()
 	defer cnc.cache.mu.Unlock()
-	for u, cn := range cnc.cache.cache {
+	for i, cn := range cnc.cache.cache {
 		deleteID := -1
 		for i, n := range cn.Nonces {
 			if n.Cosigner.GetID() == cosigner.GetID() {
@@ -347,16 +365,10 @@
 		if deleteID >= 0 {
 			if len(cn.Nonces)-1 < int(cnc.threshold) {
 				// If cosigners on this nonce drops below threshold, delete it as it's no longer usable
-				delete(cnc.cache.cache, u)
+				cnc.cache.Delete(i)
 			} else {
 				cn.Nonces = append(cn.Nonces[:deleteID], cn.Nonces[deleteID+1:]...)
 			}
 		}
 	}
 }
-
-func (cnc *CosignerNonceCache) ClearAllNonces() {
-	cnc.cache.mu.Lock()
-	defer cnc.cache.mu.Unlock()
-	cnc.cache.cache = make(map[uuid.UUID]*CachedNonce, cachePreSize)
-}
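A side note on the data structure change above, offered as commentary rather than as part of the patch: NonceCache is now a plain slice, and Delete removes an entry by index with the order-preserving append idiom. Delete takes no lock itself; in this diff its callers either already hold cache.mu (GetNonces and ClearNonces lock before iterating) or run single-threaded (the TestNonceCache unit test below). A minimal, self-contained sketch of the removal idiom, using a plain []int instead of []*CachedNonce:

    package main

    import "fmt"

    // deleteAt removes the element at index i in place, preserving order.
    // It is the same append(s[:i], s[i+1:]...) idiom used by NonceCache.Delete.
    func deleteAt(s []int, i int) []int {
        return append(s[:i], s[i+1:]...)
    }

    func main() {
        cache := []int{10, 20, 30, 40}

        cache = deleteAt(cache, len(cache)-1) // drop the last entry, like nc.Delete(nc.Size() - 1)
        cache = deleteAt(cache, 0)            // drop the first entry, like nc.Delete(0)

        fmt.Println(cache) // [20 30]
    }

Because the append shifts the tail left in place, indexes after the deleted entry change, which is why GetNonces deletes the matched entry and returns immediately rather than continuing to iterate.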
diff --git a/signer/cosigner_nonce_cache_test.go b/signer/cosigner_nonce_cache_test.go
index c7ac2586..f7d6cc55 100644
--- a/signer/cosigner_nonce_cache_test.go
+++ b/signer/cosigner_nonce_cache_test.go
@@ -3,13 +3,47 @@ package signer
 import (
 	"context"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
 	cometlog "github.com/cometbft/cometbft/libs/log"
+	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
 )
 
+func TestNonceCache(_ *testing.T) {
+	nc := NonceCache{}
+	for i := 0; i < 10; i++ {
+		nc.Add(&CachedNonce{UUID: uuid.New(), Expiration: time.Now().Add(1 * time.Second)})
+	}
+
+	nc.Delete(nc.Size() - 1)
+	nc.Delete(0)
+}
+
+type mockPruner struct {
+	cnc    *CosignerNonceCache
+	count  int
+	pruned int
+	mu     sync.Mutex
+}
+
+func (mp *mockPruner) PruneNonces() int {
+	pruned := mp.cnc.PruneNonces()
+	mp.mu.Lock()
+	defer mp.mu.Unlock()
+	mp.count++
+	mp.pruned += pruned
+	return pruned
+}
+
+func (mp *mockPruner) Result() (int, int) {
+	mp.mu.Lock()
+	defer mp.mu.Unlock()
+	return mp.count, mp.pruned
+}
+
 func TestNonceCacheDemand(t *testing.T) {
 	lcs, _ := getTestLocalCosigners(t, 2, 3)
 	cosigners := make([]Cosigner, len(lcs))
@@ -17,15 +51,21 @@
 		cosigners[i] = lc
 	}
 
+	mp := &mockPruner{}
+
 	nonceCache := NewCosignerNonceCache(
 		cometlog.NewTMLogger(cometlog.NewSyncWriter(os.Stdout)),
 		cosigners,
 		&MockLeader{id: 1, leader: &ThresholdValidator{myCosigner: lcs[0]}},
 		500*time.Millisecond,
 		100*time.Millisecond,
+		defaultNonceExpiration,
 		2,
+		mp,
 	)
 
+	mp.cnc = nonceCache
+
 	ctx, cancel := context.WithCancel(context.Background())
 
 	nonceCache.LoadN(ctx, 500)
@@ -45,6 +85,55 @@
 
 	cancel()
 
-	target := int(nonceCache.noncesPerMinute*.01) + 10
-	require.LessOrEqual(t, size, target)
+	require.LessOrEqual(t, size, nonceCache.target())
+
+	require.Greater(t, mp.count, 0)
+	require.Equal(t, 0, mp.pruned)
+}
+
+func TestNonceCacheExpiration(t *testing.T) {
+	lcs, _ := getTestLocalCosigners(t, 2, 3)
+	cosigners := make([]Cosigner, len(lcs))
+	for i, lc := range lcs {
+		cosigners[i] = lc
+	}
+
+	mp := &mockPruner{}
+
+	nonceCache := NewCosignerNonceCache(
+		cometlog.NewTMLogger(cometlog.NewSyncWriter(os.Stdout)),
+		cosigners,
+		&MockLeader{id: 1, leader: &ThresholdValidator{myCosigner: lcs[0]}},
+		250*time.Millisecond,
+		10*time.Millisecond,
+		500*time.Millisecond,
+		2,
+		mp,
+	)
+
+	mp.cnc = nonceCache
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	const loadN = 500
+
+	nonceCache.LoadN(ctx, loadN)
+
+	go nonceCache.Start(ctx)
+
+	time.Sleep(1 * time.Second)
+
+	count, pruned := mp.Result()
+
+	// we should have pruned at least three times after
+	// waiting for a second with a reconcile interval of 250ms
+	require.GreaterOrEqual(t, count, 3)
+
+	// we should have pruned at least the number of nonces we loaded and knew would expire
+	require.GreaterOrEqual(t, pruned, loadN)
+
+	cancel()
+
+	// the cache should have at most the trim padding since no nonces are being consumed.
+	require.LessOrEqual(t, nonceCache.cache.Size(), targetTrim)
 }
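The first assertion above compares the leftover cache size against the new target() helper, so it may help to spell that arithmetic out with concrete numbers. Only the formula comes from the diff; the demand figures below are made up. target() keeps 120% of one reconcile interval's worth of observed demand, plus roughly two seconds' worth of nonces (noncesPerMinute/30) to cover consumption while LoadN is in flight, plus the fixed targetTrim padding:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Hypothetical demand; the real figure is measured in reconcile().
        noncesPerMinute := 600.0
        getNoncesInterval := 3 * time.Second // defaultGetNoncesInterval
        const targetTrim = 10                // same padding constant as the diff

        // 120% of one reconcile interval's worth of demand...
        perInterval := (noncesPerMinute / 60) * getNoncesInterval.Seconds() * 1.2
        // ...plus about two seconds of demand to cover consumption during LoadN...
        loadNBuffer := noncesPerMinute / 30
        // ...plus fixed padding, mirroring cnc.target().
        target := int(perInterval) + int(loadNBuffer) + targetTrim

        fmt.Println(target) // 36 + 20 + 10 = 66 for these inputs
    }

With zero observed demand the first two terms vanish, which is why an idle cache is expected to settle at no more than targetTrim entries once everything has expired, as the final assertion in TestNonceCacheExpiration checks.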
diff --git a/signer/threshold_validator.go b/signer/threshold_validator.go
index 5a201d4b..f9dede6d 100644
--- a/signer/threshold_validator.go
+++ b/signer/threshold_validator.go
@@ -84,7 +84,9 @@ func NewThresholdValidator(
 		leader,
 		defaultGetNoncesInterval,
 		defaultGetNoncesTimeout,
+		defaultNonceExpiration,
 		uint8(threshold),
+		nil,
 	)
 	return &ThresholdValidator{
 		logger: logger,
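Production wiring, as in the hunk above, passes nil for the pruner, so the constructor falls back to cnc.pruner = cnc and the cache prunes itself; the tests instead pass a mockPruner and then point mp.cnc back at the cache. A stripped-down sketch of that injection pattern, using illustrative type and variable names that are not from the codebase (the real mockPruner also guards its counters with a mutex):

    package main

    import "fmt"

    // pruner mirrors the NonceCachePruner interface from the diff.
    type pruner interface {
        PruneNonces() int
    }

    // cache stands in for CosignerNonceCache; by default it is its own pruner.
    type cache struct {
        p     pruner
        items []int
    }

    func newCache(p pruner) *cache {
        c := &cache{items: []int{1, 2, 3}}
        if p == nil {
            p = c // same fallback as "cnc.pruner = cnc"
        }
        c.p = p
        return c
    }

    // PruneNonces plays the role of the real CosignerNonceCache.PruneNonces.
    func (c *cache) PruneNonces() int {
        n := len(c.items)
        c.items = nil
        return n
    }

    // countingPruner plays the role of the test's mockPruner: it delegates to
    // the real implementation and records how often and how much was pruned.
    type countingPruner struct {
        inner         pruner
        count, pruned int
    }

    func (cp *countingPruner) PruneNonces() int {
        n := cp.inner.PruneNonces()
        cp.count++
        cp.pruned += n
        return n
    }

    func main() {
        cp := &countingPruner{}
        c := newCache(cp) // like passing mp to NewCosignerNonceCache
        cp.inner = c      // like mp.cnc = nonceCache

        c.p.PruneNonces()                // what reconcile() does on every tick
        fmt.Println(cp.count, cp.pruned) // 1 3
    }

Keeping the hook as a one-method interface means the production path stays a direct method call on the cache, while tests get visibility into how often reconcile pruned and how many nonces were dropped.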