diff --git a/.github/workflows/beekeeper.yml b/.github/workflows/beekeeper.yml index a606e8bfc65..470712c5566 100644 --- a/.github/workflows/beekeeper.yml +++ b/.github/workflows/beekeeper.yml @@ -12,7 +12,7 @@ env: REPLICA: 3 RUN_TYPE: "PR RUN" SETUP_CONTRACT_IMAGE: "ethersphere/bee-localchain" - SETUP_CONTRACT_IMAGE_TAG: "0.9.0-rc3" + SETUP_CONTRACT_IMAGE_TAG: "0.9.1-rc1" BEELOCAL_BRANCH: "main" BEEKEEPER_BRANCH: "master" BEEKEEPER_METRICS_ENABLED: false diff --git a/go.mod b/go.mod index 8ce4086f3b7..dac2171ca23 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/coreos/go-semver v0.3.0 github.com/ethereum/go-ethereum v1.14.3 github.com/ethersphere/go-price-oracle-abi v0.2.0 - github.com/ethersphere/go-storage-incentives-abi v0.9.0-rc3 + github.com/ethersphere/go-storage-incentives-abi v0.9.1-rc1 github.com/ethersphere/go-sw3-abi v0.6.5 github.com/ethersphere/langos v1.0.0 github.com/go-playground/validator/v10 v10.11.1 diff --git a/go.sum b/go.sum index 09b91a216ae..acde65cdd7b 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/ethereum/go-ethereum v1.14.3 h1:5zvnAqLtnCZrU9uod1JCvHWJbPMURzYFHfc2e github.com/ethereum/go-ethereum v1.14.3/go.mod h1:1STrq471D0BQbCX9He0hUj4bHxX2k6mt5nOQJhDNOJ8= github.com/ethersphere/go-price-oracle-abi v0.2.0 h1:wtIcYLgNZHY4BjYwJCnu93SvJdVAZVvBaKinspyyHvQ= github.com/ethersphere/go-price-oracle-abi v0.2.0/go.mod h1:sI/Qj4/zJ23/b1enzwMMv0/hLTpPNVNacEwCWjo6yBk= -github.com/ethersphere/go-storage-incentives-abi v0.9.0-rc3 h1:TCCGtf1jODBUusTiH94Nhgw03apgchQZaJ03L/vt4z4= -github.com/ethersphere/go-storage-incentives-abi v0.9.0-rc3/go.mod h1:SXvJVtM4sEsaSKD0jc1ClpDLw8ErPoROZDme4Wrc/Nc= +github.com/ethersphere/go-storage-incentives-abi v0.9.1-rc1 h1:+BfZEv0zIN9MKQeRsZG/y3KruIuPlHHaG01oJ2wAvNA= +github.com/ethersphere/go-storage-incentives-abi v0.9.1-rc1/go.mod h1:SXvJVtM4sEsaSKD0jc1ClpDLw8ErPoROZDme4Wrc/Nc= github.com/ethersphere/go-sw3-abi v0.6.5 h1:M5dcIe1zQYvGpY2K07UNkNU9Obc4U+A1fz68Ho/Q+XE= github.com/ethersphere/go-sw3-abi v0.6.5/go.mod h1:BmpsvJ8idQZdYEtWnvxA8POYQ8Rl/NhyCdF0zLMOOJU= github.com/ethersphere/langos v1.0.0 h1:NBtNKzXTTRSue95uOlzPN4py7Aofs0xWPzyj4AI1Vcc= diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml index de8614c879f..ffd6ee14b6b 100644 --- a/openapi/Swarm.yaml +++ b/openapi/Swarm.yaml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: - version: 7.0.0 + version: 7.1.0 title: Bee API description: "A list of the currently provided Interfaces to interact with the swarm, implementing file operations and sending messages" @@ -2033,6 +2033,34 @@ paths: default: description: Default response + "/stake/migrate": + post: + summary: Withdraws all past staked amount back to the wallet. + description: Be aware, the endpoint call only be called when the contract is paused and is in the process of being migrated to a new contract. + tags: + - Staking + responses: + "200": + $ref: "SwarmCommon.yaml#/components/schemas/StakeTransactionResponse" + "500": + $ref: "SwarmCommon.yaml#/components/responses/500" + default: + description: Default response + + "/stake/withdrawable": + get: + summary: Get the withdrawable staked amount. + description: This endpoint fetches any amount that is possible to withdraw as surplus. + tags: + - Staking + responses: + "200": + $ref: "SwarmCommon.yaml#/components/schemas/GetStakeResponse" + "500": + $ref: "SwarmCommon.yaml#/components/responses/500" + default: + description: Default response + "/stake/{amount}": post: summary: Deposit some amount for staking. @@ -2049,7 +2077,7 @@ paths: - $ref: "SwarmCommon.yaml#/components/parameters/GasLimitParameter" responses: "200": - $ref: "SwarmCommon.yaml#/components/schemas/StakeDepositResponse" + $ref: "SwarmCommon.yaml#/components/schemas/StakeTransactionResponse" "400": $ref: "SwarmCommon.yaml#/components/responses/400" "500": @@ -2060,7 +2088,7 @@ paths: "/stake": get: summary: Get the staked amount. - description: This endpoint fetches the staked amount from the blockchain. + description: This endpoint fetches the total staked amount from the blockchain. tags: - Staking responses: @@ -2071,8 +2099,8 @@ paths: default: description: Default response delete: - summary: Withdraw all staked amount. - description: Be aware, this endpoint creates an on-chain transactions and transfers BZZ from the node's Ethereum account and hence directly manipulates the wallet balance. + summary: Withdraw the extra withdrawable staked amount. + description: This endpoint withdraws any amount that is possible to withdraw as surplus. tags: - Staking parameters: @@ -2080,7 +2108,7 @@ paths: - $ref: "SwarmCommon.yaml#/components/parameters/GasLimitParameter" responses: "200": - $ref: "SwarmCommon.yaml#/components/schemas/WithdrawAllStakeResponse" + $ref: "SwarmCommon.yaml#/components/schemas/StakeTransactionResponse" "400": $ref: "SwarmCommon.yaml#/components/responses/400" "500": diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index 1900ca2805b..50e56a96949 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -1,6 +1,6 @@ openapi: 3.0.3 info: - version: 4.0.0 + version: 4.1.0 title: Common Data Types description: | \*****bzzz***** @@ -605,18 +605,12 @@ components: stakedAmount: $ref: "#/components/schemas/BigInt" - StakeDepositResponse: + StakeTransactionResponse: type: object properties: txHash: $ref: "#/components/schemas/TransactionHash" - WithdrawAllStakeResponse: - type: object - properties: - txHash: - $ref: "#/components/schemas/TransactionHash" - SwarmOnlyReference: oneOf: - $ref: "#/components/schemas/SwarmAddress" diff --git a/pkg/api/export_test.go b/pkg/api/export_test.go index 6eba8a56841..c0f2d2f29fe 100644 --- a/pkg/api/export_test.go +++ b/pkg/api/export_test.go @@ -95,7 +95,7 @@ type ( WalletResponse = walletResponse WalletTxResponse = walletTxResponse GetStakeResponse = getStakeResponse - WithdrawAllStakeResponse = withdrawAllStakeResponse + StakeTransactionReponse = stakeTransactionReponse StatusSnapshotResponse = statusSnapshotResponse StatusResponse = statusResponse ) diff --git a/pkg/api/router.go b/pkg/api/router.go index b4dcb5423df..7018d1f6fcd 100644 --- a/pkg/api/router.go +++ b/pkg/api/router.go @@ -559,6 +559,22 @@ func (s *Service) mountBusinessDebug() { web.FinalHandlerFunc(s.healthHandler), )) + handle("/stake/migrate", web.ChainHandlers( + s.stakingAccessHandler, + s.gasConfigMiddleware("migrate stake"), + web.FinalHandler(jsonhttp.MethodHandler{ + "POST": http.HandlerFunc(s.migrateStakeHandler), + })), + ) + + handle("/stake/withdrawable", web.ChainHandlers( + s.stakingAccessHandler, + s.gasConfigMiddleware("get withdrawable stake"), + web.FinalHandler(jsonhttp.MethodHandler{ + "GET": http.HandlerFunc(s.getWithdrawableStakeHandler), + })), + ) + handle("/stake/{amount}", web.ChainHandlers( s.stakingAccessHandler, s.gasConfigMiddleware("deposit stake"), @@ -571,8 +587,8 @@ func (s *Service) mountBusinessDebug() { s.stakingAccessHandler, s.gasConfigMiddleware("get or withdraw stake"), web.FinalHandler(jsonhttp.MethodHandler{ - "GET": http.HandlerFunc(s.getStakedAmountHandler), - "DELETE": http.HandlerFunc(s.withdrawAllStakeHandler), + "GET": http.HandlerFunc(s.getPotentialStake), + "DELETE": http.HandlerFunc(s.withdrawStakeHandler), })), ) handle("/redistributionstate", jsonhttp.MethodHandler{ diff --git a/pkg/api/staking.go b/pkg/api/staking.go index 5ca5b06d409..581ba00f197 100644 --- a/pkg/api/staking.go +++ b/pkg/api/staking.go @@ -33,11 +33,7 @@ func (s *Service) stakingAccessHandler(h http.Handler) http.Handler { type getStakeResponse struct { StakedAmount *bigint.BigInt `json:"stakedAmount"` } -type stakeDepositResponse struct { - TxHash string `json:"txhash"` -} - -type withdrawAllStakeResponse struct { +type stakeTransactionReponse struct { TxHash string `json:"txhash"` } @@ -77,15 +73,29 @@ func (s *Service) stakingDepositHandler(w http.ResponseWriter, r *http.Request) jsonhttp.InternalServerError(w, "cannot stake") return } - jsonhttp.OK(w, stakeDepositResponse{ + jsonhttp.OK(w, stakeTransactionReponse{ TxHash: txHash.String(), }) } -func (s *Service) getStakedAmountHandler(w http.ResponseWriter, r *http.Request) { +func (s *Service) getPotentialStake(w http.ResponseWriter, r *http.Request) { + logger := s.logger.WithName("get_stake").Build() + + stakedAmount, err := s.stakingContract.GetPotentialStake(r.Context()) + if err != nil { + logger.Debug("get staked amount failed", "overlayAddr", s.overlay, "error", err) + logger.Error(nil, "get staked amount failed") + jsonhttp.InternalServerError(w, "get staked amount failed") + return + } + + jsonhttp.OK(w, getStakeResponse{StakedAmount: bigint.Wrap(stakedAmount)}) +} + +func (s *Service) getWithdrawableStakeHandler(w http.ResponseWriter, r *http.Request) { logger := s.logger.WithName("get_stake").Build() - stakedAmount, err := s.stakingContract.GetStake(r.Context()) + stakedAmount, err := s.stakingContract.GetWithdrawableStake(r.Context()) if err != nil { logger.Debug("get staked amount failed", "overlayAddr", s.overlay, "error", err) logger.Error(nil, "get staked amount failed") @@ -96,10 +106,10 @@ func (s *Service) getStakedAmountHandler(w http.ResponseWriter, r *http.Request) jsonhttp.OK(w, getStakeResponse{StakedAmount: bigint.Wrap(stakedAmount)}) } -func (s *Service) withdrawAllStakeHandler(w http.ResponseWriter, r *http.Request) { - logger := s.logger.WithName("delete_withdraw_all_stake").Build() +func (s *Service) withdrawStakeHandler(w http.ResponseWriter, r *http.Request) { + logger := s.logger.WithName("withdraw_stake").Build() - txHash, err := s.stakingContract.WithdrawAllStake(r.Context()) + txHash, err := s.stakingContract.WithdrawStake(r.Context()) if err != nil { if errors.Is(err, staking.ErrInsufficientStake) { logger.Debug("insufficient stake", "overlayAddr", s.overlay, "error", err) @@ -113,5 +123,31 @@ func (s *Service) withdrawAllStakeHandler(w http.ResponseWriter, r *http.Request return } - jsonhttp.OK(w, withdrawAllStakeResponse{TxHash: txHash.String()}) + jsonhttp.OK(w, stakeTransactionReponse{TxHash: txHash.String()}) +} + +func (s *Service) migrateStakeHandler(w http.ResponseWriter, r *http.Request) { + logger := s.logger.WithName("migrate_stake").Build() + + txHash, err := s.stakingContract.MigrateStake(r.Context()) + if err != nil { + if errors.Is(err, staking.ErrInsufficientStake) { + logger.Debug("insufficient stake", "overlayAddr", s.overlay, "error", err) + logger.Error(nil, "insufficient stake") + jsonhttp.BadRequest(w, "insufficient stake to migrate") + return + } + if errors.Is(err, staking.ErrNotPaused) { + logger.Debug("contract is not paused", "error", err) + logger.Error(nil, "contract is not paused") + jsonhttp.BadRequest(w, "contract is not paused") + return + } + logger.Debug("migrate stake failed", "error", err) + logger.Error(nil, "migrate stake failed") + jsonhttp.InternalServerError(w, "cannot migrate stake") + return + } + + jsonhttp.OK(w, stakeTransactionReponse{TxHash: txHash.String()}) } diff --git a/pkg/api/staking_test.go b/pkg/api/staking_test.go index 0d6dac27587..3375ae341fb 100644 --- a/pkg/api/staking_test.go +++ b/pkg/api/staking_test.go @@ -105,7 +105,7 @@ func TestDepositStake(t *testing.T) { }) } -func TestGetStake(t *testing.T) { +func TestGetStakeCommitted(t *testing.T) { t.Parallel() t.Run("ok", func(t *testing.T) { @@ -135,6 +135,36 @@ func TestGetStake(t *testing.T) { }) } +func TestGetStakeWithdrawable(t *testing.T) { + t.Parallel() + + t.Run("ok", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithGetStake(func(ctx context.Context) (*big.Int, error) { + return big.NewInt(1), nil + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodGet, "/stake/withdrawable", http.StatusOK, + jsonhttptest.WithExpectedJSONResponse(&api.GetStakeResponse{StakedAmount: bigint.Wrap(big.NewInt(1))})) + }) + + t.Run("with error", func(t *testing.T) { + t.Parallel() + + contractWithError := stakingContractMock.New( + stakingContractMock.WithGetStake(func(ctx context.Context) (*big.Int, error) { + return big.NewInt(0), fmt.Errorf("get stake failed") + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contractWithError}) + jsonhttptest.Request(t, ts, http.MethodGet, "/stake/withdrawable", http.StatusInternalServerError, + jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{Code: http.StatusInternalServerError, Message: "get staked amount failed"})) + }) +} + func Test_stakingDepositHandler_invalidInputs(t *testing.T) { t.Parallel() @@ -171,7 +201,7 @@ func Test_stakingDepositHandler_invalidInputs(t *testing.T) { } } -func TestWithdrawAllStake(t *testing.T) { +func TestWithdrawStake(t *testing.T) { t.Parallel() txHash := common.HexToHash("0x1234") @@ -180,20 +210,20 @@ func TestWithdrawAllStake(t *testing.T) { t.Parallel() contract := stakingContractMock.New( - stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + stakingContractMock.WithWithdrawStake(func(ctx context.Context) (common.Hash, error) { return txHash, nil }), ) ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contract}) jsonhttptest.Request(t, ts, http.MethodDelete, "/stake", http.StatusOK, jsonhttptest.WithExpectedJSONResponse( - &api.WithdrawAllStakeResponse{TxHash: txHash.String()})) + &api.StakeTransactionReponse{TxHash: txHash.String()})) }) t.Run("with invalid stake amount", func(t *testing.T) { t.Parallel() contract := stakingContractMock.New( - stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + stakingContractMock.WithWithdrawStake(func(ctx context.Context) (common.Hash, error) { return common.Hash{}, staking.ErrInsufficientStake }), ) @@ -206,7 +236,7 @@ func TestWithdrawAllStake(t *testing.T) { t.Parallel() contract := stakingContractMock.New( - stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + stakingContractMock.WithWithdrawStake(func(ctx context.Context) (common.Hash, error) { return common.Hash{}, fmt.Errorf("some error") }), ) @@ -219,7 +249,7 @@ func TestWithdrawAllStake(t *testing.T) { t.Parallel() contract := stakingContractMock.New( - stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + stakingContractMock.WithWithdrawStake(func(ctx context.Context) (common.Hash, error) { gasLimit := sctx.GetGasLimit(ctx) if gasLimit != 2000000 { t.Fatalf("want 2000000, got %d", gasLimit) @@ -236,3 +266,69 @@ func TestWithdrawAllStake(t *testing.T) { ) }) } + +func TestMigrateStake(t *testing.T) { + t.Parallel() + + txHash := common.HexToHash("0x1234") + + t.Run("ok", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithMigrateStake(func(ctx context.Context) (common.Hash, error) { + return txHash, nil + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodPost, "/stake/migrate", http.StatusOK, jsonhttptest.WithExpectedJSONResponse( + &api.StakeTransactionReponse{TxHash: txHash.String()})) + }) + + t.Run("with invalid stake amount", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithMigrateStake(func(ctx context.Context) (common.Hash, error) { + return common.Hash{}, staking.ErrInsufficientStake + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodPost, "/stake/migrate", http.StatusBadRequest, + jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{Code: http.StatusBadRequest, Message: "insufficient stake to migrate"})) + }) + + t.Run("internal error", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithMigrateStake(func(ctx context.Context) (common.Hash, error) { + return common.Hash{}, fmt.Errorf("some error") + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodPost, "/stake/migrate", http.StatusInternalServerError) + jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{Code: http.StatusInternalServerError, Message: "cannot withdraw stake"}) + }) + + t.Run("gas limit header", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithMigrateStake(func(ctx context.Context) (common.Hash, error) { + gasLimit := sctx.GetGasLimit(ctx) + if gasLimit != 2000000 { + t.Fatalf("want 2000000, got %d", gasLimit) + } + return txHash, nil + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{ + StakingContract: contract, + }) + + jsonhttptest.Request(t, ts, http.MethodPost, "/stake/migrate", http.StatusOK, + jsonhttptest.WithRequestHeader(api.GasLimitHeader, "2000000"), + ) + }) +} diff --git a/pkg/node/devnode.go b/pkg/node/devnode.go index c8e35936920..1c8347bd194 100644 --- a/pkg/node/devnode.go +++ b/pkg/node/devnode.go @@ -315,7 +315,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) { stakingContractMock.WithGetStake(func(ctx context.Context) (*big.Int, error) { return nil, staking.ErrNotImplemented }), - stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + stakingContractMock.WithWithdrawStake(func(ctx context.Context) (common.Hash, error) { return common.Hash{}, staking.ErrNotImplemented }), stakingContractMock.WithIsFrozen(func(ctx context.Context, block uint64) (bool, error) { diff --git a/pkg/node/node.go b/pkg/node/node.go index f7e5f1f0ae1..52680b19ce6 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -971,7 +971,7 @@ func NewBee( stakingContract := staking.New(overlayEthAddress, stakingContractAddress, abiutil.MustParseABI(chainCfg.StakingABI), bzzTokenAddress, transactionService, common.BytesToHash(nonce), o.TrxDebugMode) if chainEnabled && changedOverlay { - stake, err := stakingContract.GetStake(ctx) + stake, err := stakingContract.GetPotentialStake(ctx) if err != nil { return nil, errors.New("getting stake balance") } diff --git a/pkg/storageincentives/staking/contract.go b/pkg/storageincentives/staking/contract.go index 736d86c4216..a26e56c2c9d 100644 --- a/pkg/storageincentives/staking/contract.go +++ b/pkg/storageincentives/staking/contract.go @@ -33,13 +33,16 @@ var ( approveDescription = "Approve tokens for stake deposit operations" depositStakeDescription = "Deposit Stake" withdrawStakeDescription = "Withdraw stake" + migrateStakeDescription = "Migrate stake" ) type Contract interface { DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) ChangeStakeOverlay(ctx context.Context, nonce common.Hash) (common.Hash, error) - GetStake(ctx context.Context) (*big.Int, error) - WithdrawAllStake(ctx context.Context) (common.Hash, error) + GetPotentialStake(ctx context.Context) (*big.Int, error) + GetWithdrawableStake(ctx context.Context) (*big.Int, error) + WithdrawStake(ctx context.Context) (common.Hash, error) + MigrateStake(ctx context.Context) (common.Hash, error) RedistributionStatuser } @@ -83,6 +86,134 @@ func New( } } +func (c *contract) DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) { + prevStakedAmount, err := c.GetPotentialStake(ctx) + if err != nil { + return common.Hash{}, err + } + + if len(prevStakedAmount.Bits()) == 0 { + if stakedAmount.Cmp(MinimumStakeAmount) == -1 { + return common.Hash{}, ErrInsufficientStakeAmount + } + } + + balance, err := c.getBalance(ctx) + if err != nil { + return common.Hash{}, err + } + + if balance.Cmp(stakedAmount) < 0 { + return common.Hash{}, ErrInsufficientFunds + } + + _, err = c.sendApproveTransaction(ctx, stakedAmount) + if err != nil { + return common.Hash{}, err + } + + receipt, err := c.sendDepositStakeTransaction(ctx, stakedAmount, c.overlayNonce) + if err != nil { + return common.Hash{}, err + } + + return receipt.TxHash, nil +} + +// ChangeStakeOverlay only changes the overlay address used in the redistribution game. +func (c *contract) ChangeStakeOverlay(ctx context.Context, nonce common.Hash) (common.Hash, error) { + c.overlayNonce = nonce + receipt, err := c.sendDepositStakeTransaction(ctx, new(big.Int), c.overlayNonce) + if err != nil { + return common.Hash{}, err + } + + return receipt.TxHash, nil +} + +func (c *contract) GetPotentialStake(ctx context.Context) (*big.Int, error) { + stakedAmount, err := c.getPotentialStake(ctx) + if err != nil { + return nil, fmt.Errorf("staking contract: failed to get stake: %w", err) + } + return stakedAmount, nil +} + +func (c *contract) GetWithdrawableStake(ctx context.Context) (*big.Int, error) { + stakedAmount, err := c.getwithdrawableStake(ctx) + if err != nil { + return nil, fmt.Errorf("staking contract: failed to get stake: %w", err) + } + return stakedAmount, nil +} + +func (c *contract) WithdrawStake(ctx context.Context) (txHash common.Hash, err error) { + stakedAmount, err := c.getwithdrawableStake(ctx) + if err != nil { + return + } + + if stakedAmount.Cmp(big.NewInt(0)) <= 0 { + return common.Hash{}, ErrInsufficientStake + } + + receipt, err := c.withdrawFromStake(ctx) + if err != nil { + return common.Hash{}, err + } + if receipt != nil { + txHash = receipt.TxHash + } + return txHash, nil +} + +func (c *contract) MigrateStake(ctx context.Context) (txHash common.Hash, err error) { + isPaused, err := c.paused(ctx) + if err != nil { + return + } + if !isPaused { + return common.Hash{}, ErrNotPaused + } + + receipt, err := c.migrateStake(ctx) + if err != nil { + return common.Hash{}, err + } + if receipt != nil { + txHash = receipt.TxHash + } + return txHash, nil +} + +func (c *contract) IsOverlayFrozen(ctx context.Context, block uint64) (bool, error) { + callData, err := c.stakingContractABI.Pack("lastUpdatedBlockNumberOfAddress", c.owner) + if err != nil { + return false, err + } + + result, err := c.transactionService.Call(ctx, &transaction.TxRequest{ + To: &c.stakingContractAddress, + Data: callData, + }) + if err != nil { + return false, err + } + + results, err := c.stakingContractABI.Unpack("lastUpdatedBlockNumberOfOverlay", result) + if err != nil { + return false, err + } + + if len(results) == 0 { + return false, errors.New("unexpected empty results") + } + + lastUpdate := abi.ConvertType(results[0], new(big.Int)).(*big.Int) + + return lastUpdate.Uint64() >= block, nil +} + func (c *contract) sendApproveTransaction(ctx context.Context, amount *big.Int) (receipt *types.Receipt, err error) { callData, err := erc20ABI.Pack("approve", c.stakingContractAddress, amount) if err != nil { @@ -174,8 +305,8 @@ func (c *contract) sendDepositStakeTransaction(ctx context.Context, stakedAmount return receipt, nil } -func (c *contract) getStake(ctx context.Context) (*big.Int, error) { - callData, err := c.stakingContractABI.Pack("stakeOfAddress", c.owner) +func (c *contract) getPotentialStake(ctx context.Context) (*big.Int, error) { + callData, err := c.stakingContractABI.Pack("stakes", c.owner) if err != nil { return nil, err } @@ -184,72 +315,49 @@ func (c *contract) getStake(ctx context.Context) (*big.Int, error) { Data: callData, }) if err != nil { - return nil, fmt.Errorf("get stake: %w", err) + return nil, fmt.Errorf("get potential stake: %w", err) } - results, err := c.stakingContractABI.Unpack("stakeOfAddress", result) + // overlay bytes32, + // committedStake uint256, + // potentialStake uint256, + // lastUpdatedBlockNumber uint256, + // isValue bool + results, err := c.stakingContractABI.Unpack("stakes", result) if err != nil { return nil, err } - if len(results) == 0 { + if len(results) < 5 { return nil, errors.New("unexpected empty results") } - return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil + return abi.ConvertType(results[2], new(big.Int)).(*big.Int), nil } -func (c *contract) DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) { - prevStakedAmount, err := c.GetStake(ctx) - if err != nil { - return common.Hash{}, err - } - - if len(prevStakedAmount.Bits()) == 0 { - if stakedAmount.Cmp(MinimumStakeAmount) == -1 { - return common.Hash{}, ErrInsufficientStakeAmount - } - } - - balance, err := c.getBalance(ctx) +func (c *contract) getwithdrawableStake(ctx context.Context) (*big.Int, error) { + callData, err := c.stakingContractABI.Pack("withdrawableStake") if err != nil { - return common.Hash{}, err - } - - if balance.Cmp(stakedAmount) < 0 { - return common.Hash{}, ErrInsufficientFunds + return nil, err } - - _, err = c.sendApproveTransaction(ctx, stakedAmount) + result, err := c.transactionService.Call(ctx, &transaction.TxRequest{ + To: &c.stakingContractAddress, + Data: callData, + }) if err != nil { - return common.Hash{}, err + return nil, fmt.Errorf("get withdrawable stake: %w", err) } - receipt, err := c.sendDepositStakeTransaction(ctx, stakedAmount, c.overlayNonce) + results, err := c.stakingContractABI.Unpack("withdrawableStake", result) if err != nil { - return common.Hash{}, err + return nil, err } - return receipt.TxHash, nil -} - -// ChangeStakeOverlay only changes the overlay address used in the redistribution game. -func (c *contract) ChangeStakeOverlay(ctx context.Context, nonce common.Hash) (common.Hash, error) { - c.overlayNonce = nonce - receipt, err := c.sendDepositStakeTransaction(ctx, new(big.Int), c.overlayNonce) - if err != nil { - return common.Hash{}, err + if len(results) == 0 { + return nil, errors.New("unexpected empty results") } - return receipt.TxHash, nil -} - -func (c *contract) GetStake(ctx context.Context) (*big.Int, error) { - stakedAmount, err := c.getStake(ctx) - if err != nil { - return nil, fmt.Errorf("staking contract: failed to get stake: %w", err) - } - return stakedAmount, nil + return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil } func (c *contract) getBalance(ctx context.Context) (*big.Int, error) { @@ -278,48 +386,29 @@ func (c *contract) getBalance(ctx context.Context) (*big.Int, error) { return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil } -func (c *contract) WithdrawAllStake(ctx context.Context) (txHash common.Hash, err error) { - isPaused, err := c.paused(ctx) - if err != nil { - return - } - if !isPaused { - return common.Hash{}, ErrNotPaused - } - - stakedAmount, err := c.getStake(ctx) +func (c *contract) migrateStake(ctx context.Context) (*types.Receipt, error) { + callData, err := c.stakingContractABI.Pack("migrateStake") if err != nil { - return - } - - if stakedAmount.Cmp(big.NewInt(0)) <= 0 { - return common.Hash{}, ErrInsufficientStake + return nil, err } - _, err = c.sendApproveTransaction(ctx, stakedAmount) + receipt, err := c.sendTransaction(ctx, callData, migrateStakeDescription) if err != nil { - return common.Hash{}, err + return nil, fmt.Errorf("migrate stake: %w", err) } - receipt, err := c.withdrawFromStake(ctx, stakedAmount) - if err != nil { - return common.Hash{}, err - } - if receipt != nil { - txHash = receipt.TxHash - } - return txHash, nil + return receipt, nil } -func (c *contract) withdrawFromStake(ctx context.Context, stakedAmount *big.Int) (*types.Receipt, error) { - callData, err := c.stakingContractABI.Pack("withdrawFromStake", stakedAmount) +func (c *contract) withdrawFromStake(ctx context.Context) (*types.Receipt, error) { + callData, err := c.stakingContractABI.Pack("withdrawFromStake") if err != nil { return nil, err } receipt, err := c.sendTransaction(ctx, callData, withdrawStakeDescription) if err != nil { - return nil, fmt.Errorf("withdraw stake: stakedAmount %d: %w", stakedAmount, err) + return nil, fmt.Errorf("withdraw stake: %w", err) } return receipt, nil @@ -350,31 +439,3 @@ func (c *contract) paused(ctx context.Context) (bool, error) { return results[0].(bool), nil } - -func (c *contract) IsOverlayFrozen(ctx context.Context, block uint64) (bool, error) { - callData, err := c.stakingContractABI.Pack("lastUpdatedBlockNumberOfAddress", c.owner) - if err != nil { - return false, err - } - - result, err := c.transactionService.Call(ctx, &transaction.TxRequest{ - To: &c.stakingContractAddress, - Data: callData, - }) - if err != nil { - return false, err - } - - results, err := c.stakingContractABI.Unpack("lastUpdatedBlockNumberOfOverlay", result) - if err != nil { - return false, err - } - - if len(results) == 0 { - return false, errors.New("unexpected empty results") - } - - lastUpdate := abi.ConvertType(results[0], new(big.Int)).(*big.Int) - - return lastUpdate.Uint64() >= block, nil -} diff --git a/pkg/storageincentives/staking/contract_test.go b/pkg/storageincentives/staking/contract_test.go index 1a56bfa0f2d..2f66d27dd82 100644 --- a/pkg/storageincentives/staking/contract_test.go +++ b/pkg/storageincentives/staking/contract_test.go @@ -83,7 +83,7 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -144,7 +144,7 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -157,7 +157,7 @@ func TestDepositStake(t *testing.T) { if err != nil { t.Fatal(err) } - stakedAmount, err := contract.GetStake(ctx) + stakedAmount, err := contract.GetPotentialStake(ctx) if err != nil { t.Fatal(err) } @@ -183,7 +183,7 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -215,7 +215,7 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -247,7 +247,7 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -395,7 +395,8 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil + } return nil, errors.New("unexpected call") }), @@ -454,7 +455,8 @@ func TestDepositStake(t *testing.T) { return totalAmount.FillBytes(make([]byte, 32)), nil } if *request.To == stakingContractAddress { - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil + } return nil, errors.New("unexpected call") }), @@ -696,7 +698,7 @@ func TestChangeStakeOverlay(t *testing.T) { }) } -func TestGetStake(t *testing.T) { +func TestGetCommittedStake(t *testing.T) { t.Parallel() ctx := context.Background() @@ -705,14 +707,15 @@ func TestGetStake(t *testing.T) { bzzTokenAddress := common.HexToAddress("eeee") nonce := common.BytesToHash(make([]byte, 32)) + expectedCallData, err := stakingContractABI.Pack("stakes", owner) + if err != nil { + t.Fatal(err) + } + t.Run("ok", func(t *testing.T) { t.Parallel() - prevStake := big.NewInt(0) - expectedCallData, err := stakingContractABI.Pack("stakeOfAddress", owner) - if err != nil { - t.Fatal(err) - } + prevStake := big.NewInt(100000000000000000) contract := staking.New( owner, @@ -725,7 +728,7 @@ func TestGetStake(t *testing.T) { if !bytes.Equal(expectedCallData[:64], request.Data[:64]) { return nil, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallData, request.Data) } - return prevStake.FillBytes(make([]byte, 32)), nil + return getPotentialStakeResponse(t, prevStake), nil } return nil, errors.New("unexpected call") }), @@ -734,18 +737,19 @@ func TestGetStake(t *testing.T) { false, ) - stakedAmount, err := contract.GetStake(ctx) + stakedAmount, err := contract.GetPotentialStake(ctx) if err != nil { t.Fatal(err) } - if stakedAmount.Cmp(big.NewInt(100000000000000000)) == 0 { - t.Fatalf("expected %v got %v", big.NewInt(100000000000000000), stakedAmount) + + if stakedAmount.Cmp(prevStake) != 0 { + t.Fatalf("expected %v got %v", prevStake, stakedAmount) } }) t.Run("error with unpacking", func(t *testing.T) { t.Parallel() - expectedCallData, err := stakingContractABI.Pack("stakeOfAddress", owner) + expectedCallData, err := stakingContractABI.Pack("stakes", owner) if err != nil { t.Fatal(err) } @@ -770,7 +774,7 @@ func TestGetStake(t *testing.T) { false, ) - _, err = contract.GetStake(ctx) + _, err = contract.GetPotentialStake(ctx) if err == nil { t.Fatal("expected error with unpacking") } @@ -782,7 +786,7 @@ func TestGetStake(t *testing.T) { addr := swarm.MustParseHexAddress("f30c0aa7e9e2a0ef4c9b1b750ebfeaeb7c7c24da700bb089da19a46e3677824b") prevStake := big.NewInt(0) - expectedCallData, err := stakingContractABI.Pack("stakeOfAddress", common.BytesToHash(addr.Bytes())) + expectedCallData, err := stakingContractABI.Pack("stakes", common.BytesToHash(addr.Bytes())) if err != nil { t.Fatal(err) } @@ -807,7 +811,7 @@ func TestGetStake(t *testing.T) { false, ) - _, err = contract.GetStake(ctx) + _, err = contract.GetPotentialStake(ctx) if err == nil { t.Fatal("expected error due to wrong call data") } @@ -830,40 +834,138 @@ func TestGetStake(t *testing.T) { false, ) - _, err := contract.GetStake(ctx) + _, err := contract.GetPotentialStake(ctx) if err == nil { t.Fatal("expected error") } }) } -func TestWithdrawStake(t *testing.T) { +func TestGetWithdrawableStake(t *testing.T) { t.Parallel() ctx := context.Background() owner := common.HexToAddress("abcd") - stakingContractAddress := common.HexToAddress("ffff") + stakingAddress := common.HexToAddress("ffff") bzzTokenAddress := common.HexToAddress("eeee") nonce := common.BytesToHash(make([]byte, 32)) - stakedAmount := big.NewInt(100000000000000000) - txHashApprove := common.HexToHash("abb0") + + expectedCallData, err := stakingContractABI.Pack("withdrawableStake") + if err != nil { + t.Fatal(err) + } t.Run("ok", func(t *testing.T) { t.Parallel() - txHashWithdrawn := common.HexToHash("c3a1") - expected := big.NewInt(1) - expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + prevStake := big.NewInt(100000000000000000) + + contract := staking.New( + owner, + stakingAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingAddress { + if !bytes.Equal(expectedCallData[:32], request.Data[:32]) { + return nil, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallData, request.Data) + } + return prevStake.FillBytes(make([]byte, 32)), nil + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + stakedAmount, err := contract.GetWithdrawableStake(ctx) if err != nil { t.Fatal(err) } - expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", stakedAmount) + if stakedAmount.Cmp(prevStake) != 0 { + t.Fatalf("expected %v got %v", prevStake, stakedAmount) + } + }) + + t.Run("error with unpacking", func(t *testing.T) { + t.Parallel() + expectedCallData, err := stakingContractABI.Pack("withdrawableStake") if err != nil { t.Fatal(err) } - expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfAddress", owner) + contract := staking.New( + owner, + stakingAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingAddress { + if !bytes.Equal(expectedCallData[:32], request.Data[:32]) { + return nil, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallData, request.Data) + } + return []byte{}, nil + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + _, err = contract.GetPotentialStake(ctx) + if err == nil { + t.Fatal("expected error with unpacking") + } + }) + + t.Run("transaction error", func(t *testing.T) { + t.Parallel() + + contract := staking.New( + owner, + stakingAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + return nil, errors.New("some error") + }), + ), + nonce, + false, + ) + + _, err := contract.GetPotentialStake(ctx) + if err == nil { + t.Fatal("expected error") + } + }) +} + +func TestWithdrawStake(t *testing.T) { + t.Parallel() + + ctx := context.Background() + owner := common.HexToAddress("abcd") + stakingContractAddress := common.HexToAddress("ffff") + bzzTokenAddress := common.HexToAddress("eeee") + nonce := common.BytesToHash(make([]byte, 32)) + stakedAmount := big.NewInt(100000000000000000) + + t.Run("ok", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake") + if err != nil { + t.Fatal(err) + } + expectedCallDataForGetStake, err := stakingContractABI.Pack("withdrawableStake") if err != nil { t.Fatal(err) } @@ -875,9 +977,6 @@ func TestWithdrawStake(t *testing.T) { bzzTokenAddress, transactionMock.New( transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { - if *request.To == bzzTokenAddress { - return txHashApprove, nil - } if *request.To == stakingContractAddress { if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) @@ -887,11 +986,98 @@ func TestWithdrawStake(t *testing.T) { return common.Hash{}, errors.New("sent to wrong contract") }), transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { - if txHash == txHashApprove { + if txHash == txHashWithdrawn { return &types.Receipt{ Status: 1, }, nil } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForGetStake[:32], request.Data[:32]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + _, err = contract.WithdrawStake(ctx) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("has no stake", func(t *testing.T) { + t.Parallel() + + invalidStakedAmount := big.NewInt(0) + + expectedCallDataForGetStake, err := stakingContractABI.Pack("withdrawableStake") + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForGetStake[:32], request.Data[:32]) { + return invalidStakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + _, err = contract.WithdrawStake(ctx) + if !errors.Is(err, staking.ErrInsufficientStake) { + t.Fatal(err) + } + }) + + t.Run("send tx failed", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake") + if err != nil { + t.Fatal(err) + } + expectedCallDataForGetStake, err := stakingContractABI.Pack("withdrawableStake") + if err != nil { + t.Fatal(err) + } + + expectedErr := errors.New("tx err") + + contract := staking.New( + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return common.Hash{}, fmt.Errorf("send tx failed: %w", expectedErr) + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { if txHash == txHashWithdrawn { return &types.Receipt{ Status: 1, @@ -901,10 +1087,7 @@ func TestWithdrawStake(t *testing.T) { }), transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { if *request.To == stakingContractAddress { - if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { - return expected.FillBytes(make([]byte, 32)), nil - } - if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + if bytes.Equal(expectedCallDataForGetStake[:32], request.Data[:32]) { return stakedAmount.FillBytes(make([]byte, 32)), nil } } @@ -915,17 +1098,71 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) + _, err = contract.WithdrawStake(ctx) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected err %v, got %v", expectedErr, err) + } + }) + + t.Run("tx reverted", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake") + if err != nil { + t.Fatal(err) + } + expectedCallDataForGetStake, err := stakingContractABI.Pack("withdrawableStake") if err != nil { t.Fatal(err) } + + contract := staking.New( + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return txHashWithdrawn, nil + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + if txHash == txHashWithdrawn { + return &types.Receipt{ + Status: 0, + }, nil + } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForGetStake[:32], request.Data[:32]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + _, err = contract.WithdrawStake(ctx) + if err == nil { + t.Fatalf("expected non nil error, got nil") + } }) - t.Run("is paused", func(t *testing.T) { + t.Run("get stake with err", func(t *testing.T) { t.Parallel() - expected := big.NewInt(0) - expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + expectedCallDataForGetStake, err := stakingContractABI.Pack("withdrawableStake") if err != nil { t.Fatal(err) } @@ -938,8 +1175,8 @@ func TestWithdrawStake(t *testing.T) { transactionMock.New( transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { if *request.To == stakingContractAddress { - if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { - return expected.FillBytes(make([]byte, 32)), nil + if bytes.Equal(expectedCallDataForGetStake[:32], request.Data[:32]) { + return nil, fmt.Errorf("some error") } } return nil, errors.New("unexpected call") @@ -949,24 +1186,92 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) - if !errors.Is(err, staking.ErrNotPaused) { - t.Fatal(err) + _, err = contract.WithdrawStake(ctx) + if err == nil { + t.Fatalf("expected non nil error, got nil") } }) +} + +func TestMigrateStake(t *testing.T) { + t.Parallel() + + ctx := context.Background() + owner := common.HexToAddress("abcd") + stakingContractAddress := common.HexToAddress("ffff") + bzzTokenAddress := common.HexToAddress("eeee") + nonce := common.BytesToHash(make([]byte, 32)) + stakedAmount := big.NewInt(100000000000000000) + + t.Run("ok", func(t *testing.T) { + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + expectedCallDataForWithdraw, err := stakingContractABI.Pack("migrateStake") + if err != nil { + t.Fatal(err) + } + expectedCallDataForGetStake, err := stakingContractABI.Pack("nodeEffectiveStake", owner) + if err != nil { + t.Fatal(err) + } - t.Run("has no stake", func(t *testing.T) { t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") expected := big.NewInt(1) - expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + contract := staking.New( + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return txHashWithdrawn, nil + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + if txHash == txHashWithdrawn { + return &types.Receipt{ + Status: 1, + }, nil + } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + false, + ) + + _, err = contract.MigrateStake(ctx) if err != nil { t.Fatal(err) } + }) - invalidStakedAmount := big.NewInt(0) + t.Run("is paused", func(t *testing.T) { + t.Parallel() + expected := big.NewInt(0) - expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfAddress", owner) + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") if err != nil { t.Fatal(err) } @@ -982,9 +1287,6 @@ func TestWithdrawStake(t *testing.T) { if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { return expected.FillBytes(make([]byte, 32)), nil } - if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { - return invalidStakedAmount.FillBytes(make([]byte, 32)), nil - } } return nil, errors.New("unexpected call") }), @@ -993,8 +1295,8 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) - if !errors.Is(err, staking.ErrInsufficientStake) { + _, err = contract.MigrateStake(ctx) + if !errors.Is(err, staking.ErrNotPaused) { t.Fatal(err) } }) @@ -1003,16 +1305,16 @@ func TestWithdrawStake(t *testing.T) { t.Parallel() _, err := stakingContractABI.Pack("paused", owner) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } - _, err = stakingContractABI.Pack("withdrawFromStake", owner, stakedAmount) + _, err = stakingContractABI.Pack("migrateStake", owner) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } - _, err = stakingContractABI.Pack("stakeOfAddress", stakedAmount) + _, err = stakingContractABI.Pack("nodeEffectiveStake", stakedAmount) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } }) @@ -1025,13 +1327,11 @@ func TestWithdrawStake(t *testing.T) { if err != nil { t.Fatal(err) } - - expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", stakedAmount) + expectedCallDataForWithdraw, err := stakingContractABI.Pack("migrateStake") if err != nil { t.Fatal(err) } - - expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfAddress", owner) + expectedCallDataForGetStake, err := stakingContractABI.Pack("nodeEffectiveStake", owner) if err != nil { t.Fatal(err) } @@ -1043,9 +1343,6 @@ func TestWithdrawStake(t *testing.T) { bzzTokenAddress, transactionMock.New( transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { - if *request.To == bzzTokenAddress { - return txHashApprove, nil - } if *request.To == stakingContractAddress { if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) @@ -1055,11 +1352,6 @@ func TestWithdrawStake(t *testing.T) { return common.Hash{}, errors.New("sent to wrong contract") }), transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { - if txHash == txHashApprove { - return &types.Receipt{ - Status: 1, - }, nil - } if txHash == txHashWithdrawn { return &types.Receipt{ Status: 1, @@ -1083,32 +1375,31 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) + _, err = contract.MigrateStake(ctx) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } }) t.Run("tx reverted", func(t *testing.T) { - t.Parallel() - txHashWithdrawn := common.HexToHash("c3a1") - expected := big.NewInt(1) expectedCallDataForPaused, err := stakingContractABI.Pack("paused") if err != nil { t.Fatal(err) } - - expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfAddress", owner) + expectedCallDataForWithdraw, err := stakingContractABI.Pack("migrateStake") if err != nil { t.Fatal(err) } - - expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", stakedAmount) + expectedCallDataForGetStake, err := stakingContractABI.Pack("nodeEffectiveStake", owner) if err != nil { t.Fatal(err) } + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + expected := big.NewInt(1) + contract := staking.New( owner, stakingContractAddress, @@ -1116,9 +1407,6 @@ func TestWithdrawStake(t *testing.T) { bzzTokenAddress, transactionMock.New( transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { - if *request.To == bzzTokenAddress { - return txHashApprove, nil - } if *request.To == stakingContractAddress { if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) @@ -1128,11 +1416,6 @@ func TestWithdrawStake(t *testing.T) { return common.Hash{}, errors.New("sent to wrong contract") }), transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { - if txHash == txHashApprove { - return &types.Receipt{ - Status: 1, - }, nil - } if txHash == txHashWithdrawn { return &types.Receipt{ Status: 0, @@ -1156,19 +1439,21 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) + _, err = contract.MigrateStake(ctx) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } }) t.Run("is paused with err", func(t *testing.T) { - t.Parallel() + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") if err != nil { t.Fatal(err) } + t.Parallel() + contract := staking.New( owner, stakingContractAddress, @@ -1188,26 +1473,26 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) + _, err = contract.WithdrawStake(ctx) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } }) t.Run("get stake with err", func(t *testing.T) { - t.Parallel() - expected := big.NewInt(1) expectedCallDataForPaused, err := stakingContractABI.Pack("paused") if err != nil { t.Fatal(err) } - - expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfAddress", owner) + expectedCallDataForGetStake, err := stakingContractABI.Pack("nodeEffectiveStake", owner) if err != nil { t.Fatal(err) } + t.Parallel() + expected := big.NewInt(1) + contract := staking.New( owner, stakingContractAddress, @@ -1230,9 +1515,19 @@ func TestWithdrawStake(t *testing.T) { false, ) - _, err = contract.WithdrawAllStake(ctx) + _, err = contract.MigrateStake(ctx) if err == nil { - t.Fatal(err) + t.Fatalf("expected non nil error, got nil") } }) } + +func getPotentialStakeResponse(t *testing.T, amount *big.Int) []byte { + t.Helper() + + ret := make([]byte, 32+32+32+32+32+32) + copy(ret, swarm.RandAddress(t).Bytes()) + copy(ret[64:], amount.FillBytes(make([]byte, 32))) + + return ret +} diff --git a/pkg/storageincentives/staking/mock/contract.go b/pkg/storageincentives/staking/mock/contract.go index 18f06e02eef..93fab73f937 100644 --- a/pkg/storageincentives/staking/mock/contract.go +++ b/pkg/storageincentives/staking/mock/contract.go @@ -16,6 +16,7 @@ type stakingContractMock struct { depositStake func(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) getStake func(ctx context.Context) (*big.Int, error) withdrawAllStake func(ctx context.Context) (common.Hash, error) + migrateStake func(ctx context.Context) (common.Hash, error) isFrozen func(ctx context.Context, block uint64) (bool, error) } @@ -27,14 +28,22 @@ func (s *stakingContractMock) ChangeStakeOverlay(_ context.Context, h common.Has return h, nil } -func (s *stakingContractMock) GetStake(ctx context.Context) (*big.Int, error) { +func (s *stakingContractMock) GetPotentialStake(ctx context.Context) (*big.Int, error) { return s.getStake(ctx) } -func (s *stakingContractMock) WithdrawAllStake(ctx context.Context) (common.Hash, error) { +func (s *stakingContractMock) GetWithdrawableStake(ctx context.Context) (*big.Int, error) { + return s.getStake(ctx) +} + +func (s *stakingContractMock) WithdrawStake(ctx context.Context) (common.Hash, error) { return s.withdrawAllStake(ctx) } +func (s *stakingContractMock) MigrateStake(ctx context.Context) (common.Hash, error) { + return s.migrateStake(ctx) +} + func (s *stakingContractMock) IsOverlayFrozen(ctx context.Context, block uint64) (bool, error) { return s.isFrozen(ctx, block) } @@ -65,12 +74,18 @@ func WithGetStake(f func(ctx context.Context) (*big.Int, error)) Option { } } -func WithWithdrawAllStake(f func(ctx context.Context) (common.Hash, error)) Option { +func WithWithdrawStake(f func(ctx context.Context) (common.Hash, error)) Option { return func(mock *stakingContractMock) { mock.withdrawAllStake = f } } +func WithMigrateStake(f func(ctx context.Context) (common.Hash, error)) Option { + return func(mock *stakingContractMock) { + mock.migrateStake = f + } +} + func WithIsFrozen(f func(ctx context.Context, block uint64) (bool, error)) Option { return func(mock *stakingContractMock) { mock.isFrozen = f