From a759a073b6a6924bfaac62f388c1eec81949cb44 Mon Sep 17 00:00:00 2001
From: Lazar <12626340+Lazar955@users.noreply.github.com>
Date: Mon, 11 Nov 2024 15:20:59 +0100
Subject: [PATCH] chore(*): staker funding and reduce contention  (#22)

* improve btc stakers

* fund staker faster
---
 harness/app.go           |  17 +++----
 harness/babylonclient.go |  17 ++++++-
 harness/btcstaker.go     |  28 ++++++++---
 harness/manager.go       | 104 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 147 insertions(+), 19 deletions(-)

diff --git a/harness/app.go b/harness/app.go
index a6143ae..a29ebe1 100644
--- a/harness/app.go
+++ b/harness/app.go
@@ -76,7 +76,7 @@ func startHarness(cmdCtx context.Context, cfg config.Config) error {
 	vig := NewSubReporter(tm, vigilanteSender)
 	vig.Start(ctx)
 
-	fpMgr := NewFinalityProviderManager(tm, fpmSender, zap.NewNop(), numFinalityProviders, fpMgrHome, eotsDir) // todo(lazar); fp count cfg
+	fpMgr := NewFinalityProviderManager(tm, fpmSender, zap.NewNop(), numFinalityProviders, fpMgrHome, eotsDir)
 	if err = fpMgr.Initialize(ctx, cfg.NumPubRand); err != nil {
 		return err
 	}
@@ -87,21 +87,20 @@ func startHarness(cmdCtx context.Context, cfg config.Config) error {
 		if err != nil {
 			return err
 		}
-		stakers = append(stakers, NewBTCStaker(tm, stakerSender, fpMgr.randomFp().btcPk.MustToBTCPK()))
+		stakers = append(stakers, NewBTCStaker(tm, stakerSender, fpMgr.randomFp().btcPk.MustToBTCPK(), tm.fundingRequests))
 	}
 
+	// periodically check if we need to fund the staker
+	go tm.fundForever(ctx)
+
 	// fund all stakers
 	if err := tm.fundAllParties(ctx, senders(stakers)); err != nil {
 		return err
 	}
 
-	// start stakers and defer stops
-	// TODO(lazar): Ideally stakers would start on different times to reduce contention
-	// on funding BTC wallet
-	for _, staker := range stakers {
-		if err := staker.Start(ctx); err != nil {
-			return err
-		}
+	// start stakers
+	if err := startStakersInBatches(ctx, stakers); err != nil {
+		return err
 	}
 
 	go printStatsForever(ctx, tm, stopChan, cfg)
diff --git a/harness/babylonclient.go b/harness/babylonclient.go
index 85c2092..94d2cd9 100644
--- a/harness/babylonclient.go
+++ b/harness/babylonclient.go
@@ -5,7 +5,9 @@ import (
 	"encoding/hex"
 	"fmt"
 	"github.com/avast/retry-go/v4"
+	"github.com/babylonlabs-io/babylon/app/params"
 	"math/rand"
+	"sync"
 	"time"
 
 	bbn "github.com/babylonlabs-io/babylon/app"
@@ -32,6 +34,18 @@ var (
 	RtyErr    = retry.LastErrorOnly(true)
 )
 
+var (
+	once   sync.Once
+	encCfg *params.EncodingConfig
+)
+
+func getEncodingConfig() *params.EncodingConfig {
+	once.Do(func() {
+		encCfg = bbn.GetEncodingConfig()
+	})
+	return encCfg
+}
+
 type Client struct {
 	*query.QueryClient
 
@@ -47,6 +61,7 @@ func New(
 		zapLogger *zap.Logger
 		err       error
 	)
+	getEncodingConfig()
 
 	// ensure cfg is valid
 	if err := cfg.Validate(); err != nil {
@@ -74,7 +89,7 @@ func New(
 
 	// Create tmp Babylon0 app to retrieve and register codecs
 	// Need to override this manually as otherwise option from config is ignored
-	encCfg := bbn.GetEncodingConfig()
+
 	cp.Cdc = cosmos.Codec{
 		InterfaceRegistry: encCfg.InterfaceRegistry,
 		Marshaler:         encCfg.Codec,
diff --git a/harness/btcstaker.go b/harness/btcstaker.go
index 47d05ce..acddbdb 100644
--- a/harness/btcstaker.go
+++ b/harness/btcstaker.go
@@ -23,25 +23,29 @@ import (
 	"github.com/btcsuite/btcd/wire"
 	"github.com/cometbft/cometbft/crypto/tmhash"
 	sdk "github.com/cosmos/cosmos-sdk/types"
+	"strings"
 	"sync/atomic"
 	"time"
 )
 
 type BTCStaker struct {
-	tm     *TestManager
-	client *SenderWithBabylonClient
-	fpPK   *btcec.PublicKey
+	tm             *TestManager
+	client         *SenderWithBabylonClient
+	fpPK           *btcec.PublicKey
+	fundingRequest chan sdk.AccAddress
 }
 
 func NewBTCStaker(
 	tm *TestManager,
 	client *SenderWithBabylonClient,
 	finalityProviderPublicKey *btcec.PublicKey,
+	fundingRequest chan sdk.AccAddress,
 ) *BTCStaker {
 	return &BTCStaker{
-		tm:     tm,
-		client: client,
-		fpPK:   finalityProviderPublicKey,
+		tm:             tm,
+		client:         client,
+		fpPK:           finalityProviderPublicKey,
+		fundingRequest: fundingRequest,
 	}
 }
 
@@ -83,7 +87,15 @@ func (s *BTCStaker) runForever(ctx context.Context, stakerAddress btcutil.Addres
 			}
 			err = s.buildAndSendStakingTransaction(ctx, stakerAddress, stakerPk, &paramsResp.Params)
 			if err != nil {
-				fmt.Printf("🚫 Err in BTC Staker: %v\n", err)
+				fmt.Printf("🚫 Err in BTC Staker (%s), err: %v\n", s.client.BabylonAddress.String(), err)
+				if strings.Contains(strings.ToLower(err.Error()), "insufficient funds") {
+					select {
+					case s.fundingRequest <- s.client.BabylonAddress:
+						time.Sleep(5 * time.Second)
+					default:
+						fmt.Println("fundingRequest channel is full or closed")
+					}
+				}
 			}
 		}
 	}
@@ -293,7 +305,7 @@ func (s *BTCStaker) waitForTransactionConfirmation(
 	txHash *chainhash.Hash,
 	requiredDepth uint32,
 ) *bstypes.InclusionProof {
-	t := time.NewTicker(10 * time.Second)
+	t := time.NewTicker(5 * time.Second)
 	defer t.Stop()
 
 	for {
diff --git a/harness/manager.go b/harness/manager.go
index e0343a7..20e6a28 100644
--- a/harness/manager.go
+++ b/harness/manager.go
@@ -24,6 +24,7 @@ import (
 	"github.com/cosmos/cosmos-sdk/types"
 	sdk "github.com/cosmos/cosmos-sdk/types"
 	banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
+	"golang.org/x/sync/errgroup"
 	"os"
 	"path/filepath"
 	"time"
@@ -76,6 +77,7 @@ type TestManager struct {
 	manger             *container.Manager
 	babylonDir         string
 	benchConfig        benchcfg.Config
+	fundingRequests    chan sdk.AccAddress
 }
 
 // StartManager creates a test manager
@@ -214,6 +216,7 @@ func StartManager(ctx context.Context, outputsInWallet uint32, epochInterval uin
 		manger:             manager,
 		babylonDir:         babylonDir,
 		benchConfig:        runCfg,
+		fundingRequests:    make(chan sdk.AccAddress, 100),
 	}, nil
 }
 
@@ -344,7 +347,7 @@ func (tm *TestManager) fundAllParties(
 	var msgs []sdk.Msg
 
 	for _, sender := range senders {
-		msg := banktypes.NewMsgSend(fundingAddress, sender.BabylonAddress, types.NewCoins(types.NewInt64Coin("ubbn", 100000000)))
+		msg := banktypes.NewMsgSend(fundingAddress, sender.BabylonAddress, types.NewCoins(types.NewInt64Coin("ubbn", 100_000_000)))
 		msgs = append(msgs, msg)
 	}
 
@@ -364,6 +367,39 @@ func (tm *TestManager) fundAllParties(
 	return nil
 }
 
+func (tm *TestManager) fundBnnAddress(
+	ctx context.Context,
+	addr sdk.AccAddress,
+) error {
+	if err := ctx.Err(); err != nil {
+		return fmt.Errorf("context error before funding: %w", err)
+	}
+
+	fundingAccount := tm.BabylonClientNode0.MustGetAddr()
+	fundingAddress, err := sdk.AccAddressFromBech32(fundingAccount)
+	if err != nil {
+		return fmt.Errorf("failed to parse funding address: %w", err)
+	}
+
+	amount := types.NewCoins(types.NewInt64Coin("ubbn", 100_000_000))
+	msg := banktypes.NewMsgSend(fundingAddress, addr, amount)
+
+	resp, err := tm.BabylonClientNode0.ReliablySendMsg(ctx, msg, nil, nil)
+	if err != nil {
+		return fmt.Errorf("failed to send fund transaction: %w", err)
+	}
+
+	if resp == nil {
+		return fmt.Errorf("transaction response is nil")
+	}
+
+	if resp.Code != 0 {
+		return fmt.Errorf("funding transaction failed with code %d", resp.Code)
+	}
+
+	return nil
+}
+
 func (tm *TestManager) listBlocksForever(ctx context.Context) {
 	lt := time.NewTicker(5 * time.Second)
 	defer lt.Stop()
@@ -388,3 +424,69 @@ func (tm *TestManager) listBlocksForever(ctx context.Context) {
 		}
 	}
 }
+
+func (tm *TestManager) fundForever(ctx context.Context) {
+	ticker := time.NewTicker(3 * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+		case addr := <-tm.fundingRequests:
+			go func() {
+				if err := tm.fundBnnAddress(ctx, addr); err != nil {
+					fmt.Printf("🚫 Failed to fund addr %s, err %v\n", addr.String(), err)
+				}
+			}()
+		}
+	}
+}
+
+func startStakersInBatches(ctx context.Context, stakers []*BTCStaker) error {
+	const (
+		batchSize     = 25
+		batchInterval = 2 * time.Second
+	)
+
+	fmt.Printf("⌛ Starting %d stakers in batches of %d, with %s interval\n",
+		len(stakers), batchSize, batchInterval)
+
+	start := time.Now()
+	var g errgroup.Group
+	for i := 0; i < len(stakers); i += batchSize {
+		end := i + batchSize
+		if end > len(stakers) {
+			end = len(stakers)
+		}
+		batch := stakers[i:end]
+
+		g.Go(func() error {
+			return startBatch(ctx, batch)
+		})
+
+		// Wait before starting the next batch, unless it's the last batch
+		if end < len(stakers) {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-time.After(batchInterval):
+			}
+		}
+	}
+
+	elapsed := time.Since(start)
+	fmt.Printf("✅ All %d stakers started in %s\n", len(stakers), elapsed)
+
+	return g.Wait()
+}
+
+func startBatch(ctx context.Context, batch []*BTCStaker) error {
+	var g errgroup.Group
+	for _, staker := range batch {
+		g.Go(func() error {
+			return staker.Start(ctx)
+		})
+	}
+	return g.Wait()
+}