diff --git a/cmd/stakerd/main.go b/cmd/stakerd/main.go index 3e4cb89..c8d93d9 100644 --- a/cmd/stakerd/main.go +++ b/cmd/stakerd/main.go @@ -1,10 +1,13 @@ package main import ( + "context" "fmt" "net/http" "os" + "os/signal" "runtime/pprof" + "syscall" "github.com/babylonlabs-io/btc-staker/metrics" staker "github.com/babylonlabs-io/btc-staker/staker" @@ -12,16 +15,12 @@ import ( service "github.com/babylonlabs-io/btc-staker/stakerservice" "github.com/jessevdk/go-flags" - "github.com/lightningnetwork/lnd/signal" ) func main() { // Hook interceptor for os signals. - shutdownInterceptor, err := signal.Intercept() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() cfg, cfgLogger, zapLogger, err := scfg.LoadConfig() @@ -89,15 +88,13 @@ func main() { cfg, staker, cfgLogger, - shutdownInterceptor, dbBackend, ) addr := fmt.Sprintf("%s:%d", cfg.MetricsConfig.Host, cfg.MetricsConfig.ServerPort) metrics.Start(cfgLogger, addr, stakerMetrics.Registry) - err = service.RunUntilShutdown() - if err != nil { + if err = service.RunUntilShutdown(ctx); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } diff --git a/itest/e2e_test.go b/itest/e2e_test.go index 2a81a05..f196344 100644 --- a/itest/e2e_test.go +++ b/itest/e2e_test.go @@ -54,7 +54,6 @@ import ( sdkquerytypes "github.com/cosmos/cosmos-sdk/types/query" sttypes "github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/signal" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) @@ -144,7 +143,6 @@ type TestManager struct { BabylonClient *babylonclient.BabylonController WalletPubKey *btcec.PublicKey MinerAddr btcutil.Address - serverStopper *signal.Interceptor wg *sync.WaitGroup serviceAddress string StakerClient *dc.StakerServiceJsonRpcClient @@ -221,6 +219,7 @@ func (td *testStakingData) withStakingAmout(amout int64) *testStakingData { func StartManager( t *testing.T, + ctx context.Context, numMatureOutputsInWallet uint32, ) *TestManager { manager, err := containers.NewManager(t) @@ -321,9 +320,6 @@ func StartManager( walletPubKey, err := btcec.ParsePubKey(pubKeyBytes) require.NoError(t, err) - interceptor, err := signal.Intercept() - require.NoError(t, err) - addressString := fmt.Sprintf("127.0.0.1:%d", testutil.AllocateUniquePort(t)) addrPort := netip.MustParseAddrPort(addressString) address := net.TCPAddrFromAddrPort(addrPort) @@ -333,7 +329,6 @@ func StartManager( cfg, stakerApp, logger, - interceptor, dbbackend, ) @@ -341,7 +336,7 @@ func StartManager( wg.Add(1) go func() { defer wg.Done() - err := stakerService.RunUntilShutdown() + err := stakerService.RunUntilShutdown(ctx) if err != nil { t.Fatalf("Error running server: %v", err) } @@ -359,7 +354,6 @@ func StartManager( BabylonClient: bl, WalletPubKey: walletPubKey, MinerAddr: minerAddressDecoded, - serverStopper: &interceptor, wg: &wg, serviceAddress: addressString, StakerClient: stakerClient, @@ -370,8 +364,8 @@ func StartManager( } } -func (tm *TestManager) Stop(t *testing.T) { - tm.serverStopper.RequestShutdown() +func (tm *TestManager) Stop(t *testing.T, cancelFunc context.CancelFunc) { + cancelFunc() tm.wg.Wait() err := tm.manger.ClearResources() require.NoError(t, err) @@ -379,9 +373,9 @@ func (tm *TestManager) Stop(t *testing.T) { require.NoError(t, err) } -func (tm *TestManager) RestartApp(t *testing.T) { +func (tm *TestManager) RestartApp(t *testing.T, newCtx context.Context, cancelFunc context.CancelFunc) { // Restart the app with no-op action - tm.RestartAppWithAction(t, func(t *testing.T) {}) + tm.RestartAppWithAction(t, newCtx, cancelFunc, func(t *testing.T) {}) } // RestartAppWithAction: @@ -389,9 +383,9 @@ func (tm *TestManager) RestartApp(t *testing.T) { // 2. Perform provided action. Warning:this action must not use staker app as // app is stopped at this point // 3. Start the staker app -func (tm *TestManager) RestartAppWithAction(t *testing.T, action func(t *testing.T)) { +func (tm *TestManager) RestartAppWithAction(t *testing.T, ctx context.Context, cancelFunc context.CancelFunc, action func(t *testing.T)) { // First stop the app - tm.serverStopper.RequestShutdown() + cancelFunc() tm.wg.Wait() // Perform the action @@ -408,14 +402,10 @@ func (tm *TestManager) RestartAppWithAction(t *testing.T, action func(t *testing stakerApp, err := staker.NewStakerAppFromConfig(tm.Config, logger, zapLogger, dbbackend, m) require.NoError(t, err) - interceptor, err := signal.Intercept() - require.NoError(t, err) - service := service.NewStakerService( tm.Config, stakerApp, logger, - interceptor, dbbackend, ) @@ -423,7 +413,7 @@ func (tm *TestManager) RestartAppWithAction(t *testing.T, action func(t *testing wg.Add(1) go func() { defer wg.Done() - err := service.RunUntilShutdown() + err := service.RunUntilShutdown(ctx) if err != nil { t.Fatalf("Error running server: %v", err) } @@ -431,7 +421,6 @@ func (tm *TestManager) RestartAppWithAction(t *testing.T, action func(t *testing // Wait for the server to start time.Sleep(3 * time.Second) - tm.serverStopper = &interceptor tm.wg = &wg tm.Db = dbbackend tm.Sa = stakerApp @@ -1110,9 +1099,11 @@ func (tm *TestManager) insertCovenantSigForDelegation(t *testing.T, btcDel *btcs } func TestStakingFailures(t *testing.T) { + t.Parallel() numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1146,12 +1137,14 @@ func TestStakingFailures(t *testing.T) { } func TestSendingStakingTransaction(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1220,12 +1213,14 @@ func TestSendingStakingTransaction(t *testing.T) { } func TestMultipleWithdrawableStakingTransactions(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1286,12 +1281,14 @@ func TestMultipleWithdrawableStakingTransactions(t *testing.T) { } func TestSendingWatchedStakingTransaction(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1308,12 +1305,14 @@ func TestSendingWatchedStakingTransaction(t *testing.T) { } func TestRestartingTxNotDeepEnough(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1325,20 +1324,24 @@ func TestRestartingTxNotDeepEnough(t *testing.T) { tm.createAndRegisterFinalityProviders(t, testStakingData) txHash := tm.sendStakingTxBTC(t, testStakingData) + newCtx, newCancel := context.WithCancel(context.Background()) + defer newCancel() // restart app when tx is not deep enough - tm.RestartApp(t) + tm.RestartApp(t, newCtx, cancel) go tm.mineNEmptyBlocks(t, params.ConfirmationTimeBlocks, true) tm.waitForStakingTxState(t, txHash, proto.TransactionState_SENT_TO_BABYLON) } func TestRestartingTxNotOnBabylon(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1362,8 +1365,10 @@ func TestRestartingTxNotOnBabylon(t *testing.T) { tm.waitForStakingTxState(t, txHash, proto.TransactionState_CONFIRMED_ON_BTC) } + newCtx, newCancel := context.WithCancel(context.Background()) + defer newCancel() // restart app, tx is confirmed but not delivered to babylon - tm.RestartApp(t) + tm.RestartApp(t, newCtx, cancel) // send headers to babylon, so that we can send delegation tx go tm.sendHeadersToBabylon(t, minedBlocks) @@ -1374,12 +1379,14 @@ func TestRestartingTxNotOnBabylon(t *testing.T) { } func TestStakingUnbonding(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1445,12 +1452,14 @@ func TestStakingUnbonding(t *testing.T) { } func TestUnbondingRestartWaitingForSignatures(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1468,8 +1477,10 @@ func TestUnbondingRestartWaitingForSignatures(t *testing.T) { tm.waitForStakingTxState(t, txHash, proto.TransactionState_SENT_TO_BABYLON) require.NoError(t, err) + newCtx, newCancel := context.WithCancel(context.Background()) + defer newCancel() // restart app, tx was sent to babylon but we did not receive covenant signatures yet - tm.RestartApp(t) + tm.RestartApp(t, newCtx, cancel) pend, err := tm.BabylonClient.QueryPendingBTCDelegations() require.NoError(t, err) @@ -1624,12 +1635,14 @@ func TestBitcoindWalletBip322Signing(t *testing.T) { } func TestSendingStakingTransaction_Restaking(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1664,12 +1677,14 @@ func TestSendingStakingTransaction_Restaking(t *testing.T) { } func TestRecoverAfterRestartDuringWithdrawal(t *testing.T) { + t.Parallel() // need to have at least 300 block on testnet as only then segwit is activated. // Mature output is out which has 100 confirmations, which means 200mature outputs // will generate 300 blocks numMatureOutputs := uint32(200) - tm := StartManager(t, numMatureOutputs) - defer tm.Stop(t) + ctx, cancel := context.WithCancel(context.Background()) + tm := StartManager(t, ctx, numMatureOutputs) + defer tm.Stop(t, cancel) tm.insertAllMinedBlocksToBabylon(t) cl := tm.Sa.BabylonController() @@ -1722,7 +1737,10 @@ func TestRecoverAfterRestartDuringWithdrawal(t *testing.T) { return true }, 1*time.Minute, eventuallyPollTime) - tm.RestartAppWithAction(t, func(t *testing.T) { + ctxAfter, cancelAfter := context.WithCancel(context.Background()) + defer cancelAfter() + + tm.RestartAppWithAction(t, ctxAfter, cancel, func(t *testing.T) { // unbodning tx got confirmed during the stop period _ = tm.mineNEmptyBlocks(t, staker.UnbondingTxConfirmations+1, false) }) diff --git a/stakerservice/service.go b/stakerservice/service.go index 9c02c54..ace294e 100644 --- a/stakerservice/service.go +++ b/stakerservice/service.go @@ -2,6 +2,7 @@ package stakerservice import ( "bytes" + "context" "encoding/hex" "fmt" "math" @@ -26,7 +27,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/signal" "github.com/sirupsen/logrus" ) @@ -41,26 +41,23 @@ type RoutesMap map[string]*rpc.RPCFunc type StakerService struct { started int32 - config *scfg.Config - staker *str.StakerApp - logger *logrus.Logger - db kvdb.Backend - interceptor signal.Interceptor + config *scfg.Config + staker *str.StakerApp + logger *logrus.Logger + db kvdb.Backend } func NewStakerService( c *scfg.Config, s *str.StakerApp, l *logrus.Logger, - sig signal.Interceptor, db kvdb.Backend, ) *StakerService { return &StakerService{ - config: c, - staker: s, - logger: l, - interceptor: sig, - db: db, + config: c, + staker: s, + logger: l, + db: db, } } @@ -563,7 +560,7 @@ func (s *StakerService) GetRoutes() RoutesMap { } } -func (s *StakerService) RunUntilShutdown() error { +func (s *StakerService) RunUntilShutdown(ctx context.Context) error { if atomic.AddInt32(&s.started, 1) != 1 { return nil } @@ -646,9 +643,8 @@ func (s *StakerService) RunUntilShutdown() error { s.logger.Info("Staker Service fully started") - // Wait for shutdown signal from either a graceful service stop or from - // the interrupt handler. - <-s.interceptor.ShutdownChannel() + // Wait for shutdown signal from either a graceful service stop or from cancel() + <-ctx.Done() s.logger.Info("Received shutdown signal. Stopping...")