From 83a96ffe47423597101610820fb30f25f9f8ab2a Mon Sep 17 00:00:00 2001 From: Hanjun Kim Date: Fri, 29 Nov 2024 16:14:40 +0900 Subject: [PATCH] refactor: return an error instead of (bool, error) this is the common convention and the caller is still able to check if the error was collections.ErrNotFound --- .../keeper/send_restriction_test.go | 9 +- x/operators/abci_test.go | 6 +- x/operators/keeper/alias_functions.go | 12 +- x/operators/keeper/genesis.go | 11 +- x/operators/keeper/genesis_test.go | 6 +- x/operators/keeper/grpc_query.go | 20 ++-- x/operators/keeper/msg_server.go | 55 ++++----- x/operators/keeper/msg_server_test.go | 24 ++-- x/operators/keeper/operators.go | 11 +- x/operators/keeper/operators_test.go | 40 +++---- x/pools/keeper/grpc_query.go | 11 +- x/pools/keeper/pools.go | 12 +- x/pools/keeper/pools_test.go | 35 +++--- x/restaking/keeper/alias_functions.go | 46 ++++---- x/restaking/keeper/alias_functions_test.go | 2 +- x/restaking/keeper/grpc_query.go | 110 ++++++++---------- x/restaking/keeper/invariants.go | 49 ++++---- x/restaking/keeper/msg_server.go | 91 +++++++-------- x/restaking/keeper/operator_restaking.go | 18 ++- x/restaking/keeper/operator_restaking_test.go | 9 +- x/restaking/keeper/operators_hooks.go | 10 +- x/restaking/keeper/operators_hooks_test.go | 18 +-- x/restaking/keeper/pool_restaking_test.go | 9 +- x/restaking/keeper/service_restaking.go | 18 ++- x/restaking/keeper/service_restaking_test.go | 9 +- x/restaking/types/expected_keepers.go | 6 +- x/rewards/keeper/allocation.go | 9 +- x/rewards/keeper/allocation_test.go | 2 +- x/rewards/keeper/common_test.go | 18 +-- x/rewards/keeper/hooks.go | 11 +- x/rewards/keeper/msg_server.go | 28 +++-- x/rewards/keeper/rewards_plan.go | 19 ++- x/rewards/keeper/target.go | 24 ++-- x/rewards/keeper/withdraw.go | 11 +- x/rewards/types/expected_keepers.go | 6 +- x/services/keeper/grpc_query.go | 20 ++-- x/services/keeper/msg_server.go | 55 ++++----- x/services/keeper/msg_server_test.go | 32 ++--- x/services/keeper/services.go | 47 +++----- x/services/keeper/services_test.go | 33 ++---- 40 files changed, 410 insertions(+), 552 deletions(-) diff --git a/x/liquidvesting/keeper/send_restriction_test.go b/x/liquidvesting/keeper/send_restriction_test.go index ce699936..a614bbfd 100644 --- a/x/liquidvesting/keeper/send_restriction_test.go +++ b/x/liquidvesting/keeper/send_restriction_test.go @@ -81,9 +81,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(pool.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) @@ -126,9 +125,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(operator.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) @@ -173,9 +171,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(service.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) diff --git a/x/operators/abci_test.go b/x/operators/abci_test.go index e6ad6857..7db77cea 100644 --- a/x/operators/abci_test.go +++ b/x/operators/abci_test.go @@ -78,9 +78,8 @@ func TestBeginBlocker(t *testing.T) { }, check: func(ctx sdk.Context) { // Make sure the operator is still inactivating - operator, found, err := operatorsKeeper.GetOperator(ctx, 1) + operator, err := operatorsKeeper.GetOperator(ctx, 1) require.NoError(t, err) - require.True(t, found) require.Equal(t, types.OPERATOR_STATUS_INACTIVATING, operator.Status) // Make sure the operator is still in the inactivating queue @@ -116,9 +115,8 @@ func TestBeginBlocker(t *testing.T) { }, check: func(ctx sdk.Context) { // Make sure the operator is inactive - operator, found, err := operatorsKeeper.GetOperator(ctx, 1) + operator, err := operatorsKeeper.GetOperator(ctx, 1) require.NoError(t, err) - require.True(t, found) require.Equal(t, types.OPERATOR_STATUS_INACTIVE, operator.Status) // Make sure the operator is not in the inactivating queue diff --git a/x/operators/keeper/alias_functions.go b/x/operators/keeper/alias_functions.go index 864a94b9..8c9892c4 100644 --- a/x/operators/keeper/alias_functions.go +++ b/x/operators/keeper/alias_functions.go @@ -2,9 +2,11 @@ package keeper import ( "context" + "errors" "fmt" "time" + "cosmossdk.io/collections" storetypes "cosmossdk.io/store/types" "github.com/cosmos/cosmos-sdk/runtime" "github.com/cosmos/cosmos-sdk/telemetry" @@ -44,15 +46,13 @@ func (k *Keeper) GetOperators(ctx context.Context) ([]types.Operator, error) { func (k *Keeper) IterateInactivatingOperatorQueue(ctx context.Context, endTime time.Time, fn func(operator types.Operator) (stop bool, err error)) error { return k.iterateInactivatingOperatorsKeys(ctx, endTime, func(key, value []byte) (stop bool, err error) { operatorID, _ := types.SplitInactivatingOperatorQueueKey(key) - operator, found, err := k.GetOperator(ctx, operatorID) + operator, err := k.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return true, fmt.Errorf("operator %d does not exist", operatorID) + } return true, err } - - if !found { - return true, fmt.Errorf("operator %d does not exist", operatorID) - } - return fn(operator) }) } diff --git a/x/operators/keeper/genesis.go b/x/operators/keeper/genesis.go index 8b18b176..76330134 100644 --- a/x/operators/keeper/genesis.go +++ b/x/operators/keeper/genesis.go @@ -1,8 +1,10 @@ package keeper import ( + "errors" "fmt" + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/milkyway-labs/milkyway/v2/x/operators/types" @@ -65,15 +67,14 @@ func (k *Keeper) InitGenesis(ctx sdk.Context, state *types.GenesisState) error { // Store the operator params for _, operatorParams := range state.OperatorsParams { // Ensure that the operator is present - _, found, err := k.GetOperator(ctx, operatorParams.OperatorID) + _, err := k.GetOperator(ctx, operatorParams.OperatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return fmt.Errorf("can't set operator params for %d, operator not found", operatorParams.OperatorID) + } return err } - if !found { - return fmt.Errorf("can't set operator params for %d, operator not found", operatorParams.OperatorID) - } - err = k.SaveOperatorParams(ctx, operatorParams.OperatorID, operatorParams.Params) if err != nil { return err diff --git a/x/operators/keeper/genesis_test.go b/x/operators/keeper/genesis_test.go index de5a4ffb..406b4824 100644 --- a/x/operators/keeper/genesis_test.go +++ b/x/operators/keeper/genesis_test.go @@ -289,9 +289,8 @@ func (suite *KeeperTestSuite) TestKeeper_InitGenesis() { Params: types.DefaultParams(), }, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -305,9 +304,8 @@ func (suite *KeeperTestSuite) TestKeeper_InitGenesis() { params, err := suite.k.GetOperatorParams(ctx, 1) suite.Require().Equal(types.DefaultOperatorParams(), params) - operator, found, err = suite.k.GetOperator(ctx, 2) + operator, err = suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, diff --git a/x/operators/keeper/grpc_query.go b/x/operators/keeper/grpc_query.go index c6adae6a..05eb1672 100644 --- a/x/operators/keeper/grpc_query.go +++ b/x/operators/keeper/grpc_query.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -14,15 +16,14 @@ var _ types.QueryServer = &Keeper{} // Operator implements the Query/Operator gRPC method func (k *Keeper) Operator(ctx context.Context, request *types.QueryOperatorRequest) (*types.QueryOperatorResponse, error) { - operator, found, err := k.GetOperator(ctx, request.OperatorId) + operator, err := k.GetOperator(ctx, request.OperatorId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "operator not found") - } - return &types.QueryOperatorResponse{Operator: operator}, nil } @@ -41,15 +42,14 @@ func (k *Keeper) Operators(ctx context.Context, request *types.QueryOperatorsReq } func (k *Keeper) OperatorParams(ctx context.Context, request *types.QueryOperatorParamsRequest) (*types.QueryOperatorParamsResponse, error) { - _, found, err := k.GetOperator(ctx, request.OperatorId) + _, err := k.GetOperator(ctx, request.OperatorId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, types.ErrOperatorNotFound - } - params, err := k.GetOperatorParams(ctx, request.OperatorId) if err != nil { return nil, err diff --git a/x/operators/keeper/msg_server.go b/x/operators/keeper/msg_server.go index 9dac7563..78d76a6c 100644 --- a/x/operators/keeper/msg_server.go +++ b/x/operators/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -98,15 +99,14 @@ func (k msgServer) RegisterOperator(ctx context.Context, msg *types.MsgRegisterO // UpdateOperator defines the rpc method for Msg/UpdateOperator func (k msgServer) UpdateOperator(ctx context.Context, msg *types.MsgUpdateOperator) (*types.MsgUpdateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can update the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the operator") @@ -141,15 +141,14 @@ func (k msgServer) UpdateOperator(ctx context.Context, msg *types.MsgUpdateOpera // DeactivateOperator defines the rpc method for Msg/DeactivateOperator func (k msgServer) DeactivateOperator(ctx context.Context, msg *types.MsgDeactivateOperator) (*types.MsgDeactivateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can deactivate the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the operator") @@ -175,15 +174,14 @@ func (k msgServer) DeactivateOperator(ctx context.Context, msg *types.MsgDeactiv // ReactivateOperator defines the rpc method for Msg/ReactivateOperator func (k msgServer) ReactivateOperator(ctx context.Context, msg *types.MsgReactivateOperator) (*types.MsgReactivateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can reactivate the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the operator") @@ -209,15 +207,14 @@ func (k msgServer) ReactivateOperator(ctx context.Context, msg *types.MsgReactiv // DeleteOperator defines the rpc method for Msg/DeleteOperator func (k msgServer) DeleteOperator(ctx context.Context, msg *types.MsgDeleteOperator) (*types.MsgDeleteOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can delete the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can delete the operator") @@ -243,15 +240,14 @@ func (k msgServer) DeleteOperator(ctx context.Context, msg *types.MsgDeleteOpera // TransferOperatorOwnership defines the rpc method for Msg/TransferOperatorOwnership func (k msgServer) TransferOperatorOwnership(ctx context.Context, msg *types.MsgTransferOperatorOwnership) (*types.MsgTransferOperatorOwnershipResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can transfer the operator ownership if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can transfer the operator ownership") @@ -278,15 +274,14 @@ func (k msgServer) TransferOperatorOwnership(ctx context.Context, msg *types.Msg // SetOperatorParams defines the rpc method for Msg/SetOperatorParams func (k msgServer) SetOperatorParams(ctx context.Context, msg *types.MsgSetOperatorParams) (*types.MsgSetOperatorParamsResponse, error) { - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can update the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the operator params") diff --git a/x/operators/keeper/msg_server_test.go b/x/operators/keeper/msg_server_test.go index 84fbda15..9e25806a 100644 --- a/x/operators/keeper/msg_server_test.go +++ b/x/operators/keeper/msg_server_test.go @@ -3,6 +3,7 @@ package keeper_test import ( "time" + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -122,9 +123,8 @@ func (suite *KeeperTestSuite) TestMsgServer_RegisterOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was stored - stored, found, err := suite.k.GetOperator(ctx, 2) + stored, err := suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, @@ -210,9 +210,8 @@ func (suite *KeeperTestSuite) TestMsgServer_RegisterOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was stored - stored, found, err := suite.k.GetOperator(ctx, 2) + stored, err := suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, @@ -387,9 +386,8 @@ func (suite *KeeperTestSuite) TestMsgServer_UpdateOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -548,9 +546,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeactivateOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -703,9 +700,8 @@ func (suite *KeeperTestSuite) TestMsgServer_ReactivateOperator() { ), }, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -823,9 +819,8 @@ func (suite *KeeperTestSuite) TestMsgServer_TransferOperatorOwnership() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -971,9 +966,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - _, found, err := suite.k.GetOperator(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetOperator(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) // Ensure the hook has been called suite.Require().True(suite.hooks.CalledMap["BeforeOperatorDeleted"]) diff --git a/x/operators/keeper/operators.go b/x/operators/keeper/operators.go index b1129015..3a162627 100644 --- a/x/operators/keeper/operators.go +++ b/x/operators/keeper/operators.go @@ -56,15 +56,8 @@ func (k *Keeper) CreateOperator(ctx context.Context, operator types.Operator) er // GetOperator returns the operator with the given ID. // If the operator does not exist, false is returned. -func (k *Keeper) GetOperator(ctx context.Context, operatorID uint32) (operator types.Operator, found bool, err error) { - operator, err = k.operators.Get(ctx, operatorID) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return types.Operator{}, false, nil - } - return types.Operator{}, false, err - } - return operator, true, nil +func (k *Keeper) GetOperator(ctx context.Context, operatorID uint32) (operator types.Operator, err error) { + return k.operators.Get(ctx, operatorID) } // SaveOperator stores the given operator in the KVStore diff --git a/x/operators/keeper/operators_test.go b/x/operators/keeper/operators_test.go index a5f8b202..a6fb4ae0 100644 --- a/x/operators/keeper/operators_test.go +++ b/x/operators/keeper/operators_test.go @@ -3,6 +3,7 @@ package keeper_test import ( "time" + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -147,9 +148,8 @@ func (suite *KeeperTestSuite) TestKeeper_CreateOperator() { shouldErr: false, check: func(ctx sdk.Context) { // Make sure the operator has been stored - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -198,14 +198,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { setup func() store func(ctx sdk.Context) operatorID uint32 - shouldErr bool expFound bool expOperator types.Operator }{ { name: "non existing operator returns false", operatorID: 1, - shouldErr: false, expFound: false, }, { @@ -223,7 +221,6 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { }, operatorID: 1, expFound: true, - shouldErr: false, expOperator: types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -248,15 +245,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { tc.store(ctx) } - operator, found, err := suite.k.GetOperator(ctx, tc.operatorID) - if tc.shouldErr { - suite.Require().Error(err) - } else { + operator, err := suite.k.GetOperator(ctx, tc.operatorID) + if tc.expFound { suite.Require().NoError(err) - suite.Require().Equal(tc.expFound, found) - if tc.expFound { - suite.Require().Equal(tc.expOperator, operator) - } + suite.Require().Equal(tc.expOperator, operator) + } else { + suite.Require().ErrorIs(err, collections.ErrNotFound) } }) } @@ -283,9 +277,8 @@ func (suite *KeeperTestSuite) TestKeeper_SaveOperator() { ), shouldErr: false, check: func(ctx sdk.Context) { - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -319,9 +312,8 @@ func (suite *KeeperTestSuite) TestKeeper_SaveOperator() { ), shouldErr: false, check: func(ctx sdk.Context) { - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -431,9 +423,8 @@ func (suite *KeeperTestSuite) TestKeeper_StartOperatorInactivation() { ), check: func(ctx sdk.Context) { // Make sure the operator status has been updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -524,9 +515,8 @@ func (suite *KeeperTestSuite) TestKeeper_CompleteOperatorInactivation() { ), check: func(ctx sdk.Context) { // Make sure the operator status has been updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVE, @@ -637,9 +627,8 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { operatorID: 1, shouldErr: false, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -665,11 +654,8 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { if tc.store != nil { tc.store(ctx) } - operator, found, err := suite.k.GetOperator(ctx, tc.operatorID) + operator, err := suite.k.GetOperator(ctx, tc.operatorID) suite.Require().NoError(err) - if !found { - suite.Fail("operator not found") - } err = suite.k.ReactivateInactiveOperator(ctx, operator) if tc.shouldErr { diff --git a/x/pools/keeper/grpc_query.go b/x/pools/keeper/grpc_query.go index 3a8e652c..8a4f5ffc 100644 --- a/x/pools/keeper/grpc_query.go +++ b/x/pools/keeper/grpc_query.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" @@ -19,15 +21,14 @@ func (k *Keeper) PoolByID(ctx context.Context, request *types.QueryPoolByIdReque return nil, status.Error(codes.InvalidArgument, "invalid pool id") } - pool, found, err := k.GetPool(ctx, request.PoolId) + pool, err := k.GetPool(ctx, request.PoolId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "pool not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "pool not found") - } - return &types.QueryPoolResponse{Pool: pool}, nil } diff --git a/x/pools/keeper/pools.go b/x/pools/keeper/pools.go index 497a5faa..b0694418 100644 --- a/x/pools/keeper/pools.go +++ b/x/pools/keeper/pools.go @@ -3,7 +3,6 @@ package keeper import ( "context" - "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -52,13 +51,6 @@ func (k *Keeper) SavePool(ctx context.Context, pool types.Pool) error { // GetPool retrieves the pool with the given ID from the store. // If the pool does not exist, false is returned instead -func (k *Keeper) GetPool(ctx context.Context, id uint32) (types.Pool, bool, error) { - pool, err := k.pools.Get(ctx, id) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return types.Pool{}, false, nil - } - return types.Pool{}, false, err - } - return pool, true, nil +func (k *Keeper) GetPool(ctx context.Context, id uint32) (types.Pool, error) { + return k.pools.Get(ctx, id) } diff --git a/x/pools/keeper/pools_test.go b/x/pools/keeper/pools_test.go index 46e447be..0af6096d 100644 --- a/x/pools/keeper/pools_test.go +++ b/x/pools/keeper/pools_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/milkyway-labs/milkyway/v2/x/pools/types" @@ -124,9 +125,8 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { pool: types.NewPool(1, "uatom"), check: func(ctx sdk.Context) { // Make sure the pool is saved properly - pool, found, err := suite.k.GetPool(ctx, 1) + pool, err := suite.k.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewPool(1, "uatom"), pool) // Make sure the pool account is created @@ -147,9 +147,8 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { shouldErr: false, check: func(ctx sdk.Context) { // Make sure the pool is saved properly - pool, found, err := suite.k.GetPool(ctx, 1) + pool, err := suite.k.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewPool(1, "usdt"), pool) // Make sure the pool account is created @@ -186,14 +185,13 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { func (suite *KeeperTestSuite) TestKeeper_GetPool() { testCases := []struct { - name string - setup func() - store func(ctx sdk.Context) - poolID uint32 - shouldErr bool - expFound bool - expPool types.Pool - check func(ctx sdk.Context) + name string + setup func() + store func(ctx sdk.Context) + poolID uint32 + expFound bool + expPool types.Pool + check func(ctx sdk.Context) }{ { name: "not found pool returns error", @@ -223,15 +221,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetPool() { tc.store(ctx) } - pool, found, err := suite.k.GetPool(ctx, tc.poolID) - if tc.shouldErr { - suite.Require().Error(err) - } else { + pool, err := suite.k.GetPool(ctx, tc.poolID) + if tc.expFound { suite.Require().NoError(err) - suite.Require().Equal(tc.expFound, found) - if tc.expFound { - suite.Require().Equal(tc.expPool, pool) - } + suite.Require().Equal(tc.expPool, pool) + } else { + suite.Require().ErrorIs(err, collections.ErrNotFound) } if tc.check != nil { diff --git a/x/restaking/keeper/alias_functions.go b/x/restaking/keeper/alias_functions.go index 48101068..519944df 100644 --- a/x/restaking/keeper/alias_functions.go +++ b/x/restaking/keeper/alias_functions.go @@ -217,7 +217,7 @@ func (k *Keeper) GetDelegationForTarget( // GetDelegationTargetFromDelegation returns the target of the given delegation. func (k *Keeper) GetDelegationTargetFromDelegation( ctx context.Context, delegation types.Delegation, -) (types.DelegationTarget, bool, error) { +) (types.DelegationTarget, error) { switch delegation.Type { case types.DELEGATION_TYPE_POOL: return k.poolsKeeper.GetPool(ctx, delegation.TargetID) @@ -226,7 +226,7 @@ func (k *Keeper) GetDelegationTargetFromDelegation( case types.DELEGATION_TYPE_OPERATOR: return k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) default: - return nil, false, nil + return nil, nil } } @@ -525,15 +525,14 @@ func (k *Keeper) GetAllDelegations(ctx context.Context) ([]types.Delegation, err func (k *Keeper) GetAllUserRestakedCoins(ctx context.Context, userAddress string) (sdk.DecCoins, error) { totalDelegatedCoins := sdk.NewDecCoins() err := k.IterateUserDelegations(ctx, userAddress, func(d types.Delegation) (bool, error) { - target, found, err := k.GetDelegationTargetFromDelegation(ctx, d) + target, err := k.GetDelegationTargetFromDelegation(ctx, d) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return true, fmt.Errorf("can't find target for delegation %d, target id: %d", d.Type, d.TargetID) + } return true, err } - if !found { - return true, fmt.Errorf("can't find target for delegation %d, target id: %d", d.Type, d.TargetID) - } - totalDelegatedCoins = totalDelegatedCoins.Add(target.TokensFromShares(d.Shares)...) return false, nil }) @@ -628,33 +627,33 @@ func (k *Keeper) PerformDelegation(ctx context.Context, data types.DelegationDat func (k *Keeper) getUnbondingDelegationTarget(ctx context.Context, ubd types.UnbondingDelegation) (types.DelegationTarget, error) { switch ubd.Type { case types.DELEGATION_TYPE_POOL: - pool, found, err := k.poolsKeeper.GetPool(ctx, ubd.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, poolstypes.ErrPoolNotFound + } return nil, err } - if !found { - return nil, poolstypes.ErrPoolNotFound - } return pool, nil case types.DELEGATION_TYPE_OPERATOR: - operator, found, err := k.operatorsKeeper.GetOperator(ctx, ubd.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } return operator, nil case types.DELEGATION_TYPE_SERVICE: - service, found, err := k.servicesKeeper.GetService(ctx, ubd.TargetID) + service, err := k.servicesKeeper.GetService(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } return service, nil default: @@ -792,15 +791,14 @@ func (k *Keeper) UnbondRestakedAssets(ctx context.Context, user sdk.AccAddress, toUndelegateTokens := sdk.NewDecCoinsFromCoins(amount...) err := k.IterateUserDelegations(ctx, user.String(), func(delegation types.Delegation) (bool, error) { - target, found, err := k.GetDelegationTargetFromDelegation(ctx, delegation) + target, err := k.GetDelegationTargetFromDelegation(ctx, delegation) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return false, nil + } return true, err } - if !found { - return false, nil - } - // Compute the shares that this delegation should have to undelegate // all the remaining tokens involvedShares, err := target.SharesFromDecCoins(toUndelegateTokens) diff --git a/x/restaking/keeper/alias_functions_test.go b/x/restaking/keeper/alias_functions_test.go index 959b7233..98995535 100644 --- a/x/restaking/keeper/alias_functions_test.go +++ b/x/restaking/keeper/alias_functions_test.go @@ -812,7 +812,7 @@ func (suite *KeeperTestSuite) TestKeeper_UnbondRestakedAssets() { suite.Require().NoError(err) suite.Assert().True(found) suite.Assert().Equal(types.DELEGATION_TYPE_OPERATOR, del.Type) - operator, _, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) suite.Assert().Equal( sdk.NewDecCoins(sdk.NewInt64DecCoin("stake", 50)), diff --git a/x/restaking/keeper/grpc_query.go b/x/restaking/keeper/grpc_query.go index 990cd310..d5f58f40 100644 --- a/x/restaking/keeper/grpc_query.go +++ b/x/restaking/keeper/grpc_query.go @@ -38,15 +38,14 @@ func (k Querier) OperatorJoinedServices(ctx context.Context, req *types.QueryOpe return nil, status.Error(codes.InvalidArgument, "operator id cannot be 0") } - _, found, err := k.operatorsKeeper.GetOperator(ctx, req.OperatorId) + _, err := k.operatorsKeeper.GetOperator(ctx, req.OperatorId) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.InvalidArgument, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.InvalidArgument, "operator not found") - } - // Get the operator joined services serviceIDs, pageResponse, err := query.CollectionPaginate(ctx, k.operatorJoinedServices, req.Pagination, func(key collections.Pair[uint32, uint32], _ collections.NoValue) (uint32, error) { @@ -124,15 +123,14 @@ func (k Querier) ServiceOperators(ctx context.Context, req *types.QueryServiceOp return nil, status.Error(codes.InvalidArgument, "service id cannot be 0") } - _, found, err := k.servicesKeeper.GetService(ctx, req.ServiceId) + _, err := k.servicesKeeper.GetService(ctx, req.ServiceId) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - eligibleOperators, pageResponse, err := query.CollectionFilteredPaginate(ctx, k.operatorJoinedServices.Indexes.Service, req.Pagination, // Filter to return only the operators that have joined the service and // that are allowed to validate it @@ -146,15 +144,14 @@ func (k Querier) ServiceOperators(ctx context.Context, req *types.QueryServiceOp // Here is k2 the operator id since the Service index provides association // between a service and the operator securing it operatorID := key.K2() - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return operatorstypes.Operator{}, errors.Wrapf( + operatorstypes.ErrOperatorNotFound, "operator %d not found", operatorID) + } return operatorstypes.Operator{}, err } - - if !found { - return operatorstypes.Operator{}, errors.Wrapf( - operatorstypes.ErrOperatorNotFound, "operator %d not found", operatorID) - } return operator, nil }, query.WithCollectionPaginationPairPrefix[uint32, uint32](req.ServiceId)) if err != nil { @@ -834,15 +831,14 @@ func (k Querier) DelegatorPools(ctx context.Context, req *types.QueryDelegatorPo return err } - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return poolstypes.ErrPoolNotFound + } return err } - if !found { - return poolstypes.ErrPoolNotFound - } - pools = append(pools, pool) return nil @@ -879,15 +875,14 @@ func (k Querier) DelegatorPool(ctx context.Context, req *types.QueryDelegatorPoo return nil, status.Error(codes.NotFound, "pool delegation not found") } - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "pool not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "pool not found") - } - return &types.QueryDelegatorPoolResponse{ Pool: pool, }, nil @@ -915,15 +910,14 @@ func (k Querier) DelegatorOperators(ctx context.Context, req *types.QueryDelegat return err } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return operatorstypes.ErrOperatorNotFound + } return err } - if !found { - return operatorstypes.ErrOperatorNotFound - } - operators = append(operators, operator) return nil @@ -960,15 +954,14 @@ func (k Querier) DelegatorOperator(ctx context.Context, req *types.QueryDelegato return nil, status.Error(codes.NotFound, "operator delegation not found") } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "operator not found") - } - return &types.QueryDelegatorOperatorResponse{ Operator: operator, }, nil @@ -996,15 +989,14 @@ func (k Querier) DelegatorServices(ctx context.Context, req *types.QueryDelegato return err } - pool, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + pool, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - services = append(services, pool) return nil @@ -1041,15 +1033,14 @@ func (k Querier) DelegatorService(ctx context.Context, req *types.QueryDelegator return nil, status.Error(codes.NotFound, "service delegation not found") } - service, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + service, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - return &types.QueryDelegatorServiceResponse{ Service: service, }, nil @@ -1089,45 +1080,42 @@ func (k Querier) Params(ctx context.Context, _ *types.QueryParamsRequest) (*type // PoolDelegationToPoolDelegationResponse converts a PoolDelegation to a PoolDelegationResponse func PoolDelegationToPoolDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, poolstypes.ErrPoolNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, poolstypes.ErrPoolNotFound - } - truncatedBalance, _ := pool.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } // OperatorDelegationToOperatorDelegationResponse converts a OperatorDelegation to a OperatorDelegationResponse func OperatorDelegationToOperatorDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, operatorstypes.ErrOperatorNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, operatorstypes.ErrOperatorNotFound - } - truncatedBalance, _ := operator.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } // ServiceDelegationToServiceDelegationResponse converts a ServiceDelegation to a ServiceDelegationResponse func ServiceDelegationToServiceDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - service, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + service, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, servicestypes.ErrServiceNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, servicestypes.ErrServiceNotFound - } - truncatedBalance, _ := service.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } diff --git a/x/restaking/keeper/invariants.go b/x/restaking/keeper/invariants.go index db68a9c0..d1522d3c 100644 --- a/x/restaking/keeper/invariants.go +++ b/x/restaking/keeper/invariants.go @@ -1,8 +1,10 @@ package keeper import ( + "errors" "fmt" + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" operatorstypes "github.com/milkyway-labs/milkyway/v2/x/operators/types" @@ -182,15 +184,14 @@ func PoolsDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for poolID, delegatorsShares := range poolsDelegatorsShares { - pool, found, err := k.poolsKeeper.GetPool(ctx, poolID) + pool, err := k.poolsKeeper.GetPool(ctx, poolID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("pool with id %d not found", poolID)) + } panic(err) } - if !found { - panic(fmt.Errorf("pool with id %d not found", poolID)) - } - sharesAmount := delegatorsShares.AmountOf(pool.GetSharesDenom(pool.Denom)) if !pool.DelegatorShares.Equal(sharesAmount) { broken = true @@ -266,15 +267,14 @@ func OperatorsDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for operatorID, delegatorsShares := range operatorsDelegatorsShares { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("operator with id %d not found", operatorID)) + } panic(err) } - if !found { - panic(fmt.Errorf("operator with id %d not found", operatorID)) - } - if !operator.DelegatorShares.Equal(delegatorsShares) { broken = true msg += fmt.Sprintf("operator %d total shares: %v, delegators shares: %v\n", operatorID, operator.DelegatorShares, delegatorsShares) @@ -348,15 +348,14 @@ func ServicesDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for serviceID, delegatorsShares := range servicesDelegatorsShares { - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("service with id %d not found", serviceID)) + } panic(err) } - if !found { - panic(fmt.Errorf("service with id %d not found", serviceID)) - } - if !service.DelegatorShares.Equal(delegatorsShares) { broken = true msg += fmt.Sprintf("service %d total shares: %v, delegators shares: %v\n", serviceID, service.DelegatorShares, delegatorsShares) @@ -374,15 +373,13 @@ func AllowedOperatorsExistInvariant(k *Keeper) sdk.Invariant { // Iterate over all the services joined by operators var notFoundOperatorsIDs []uint32 err := k.IterateAllServicesAllowedOperators(ctx, func(serviceID uint32, operatorID uint32) (stop bool, err error) { - _, found, err := k.operatorsKeeper.GetOperator(ctx, serviceID) + _, err = k.operatorsKeeper.GetOperator(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + notFoundOperatorsIDs = append(notFoundOperatorsIDs, operatorID) + } return true, err } - - if !found { - notFoundOperatorsIDs = append(notFoundOperatorsIDs, operatorID) - } - return false, nil }) if err != nil { @@ -405,15 +402,13 @@ func OperatorsJoinedServicesExistInvariant(k *Keeper) sdk.Invariant { // Iterate over all the operators joined services var notFoundServicesIDs []uint32 err := k.IterateAllOperatorsJoinedServices(ctx, func(operatorID uint32, serviceID uint32) (stop bool, err error) { - _, found, err := k.servicesKeeper.GetService(ctx, serviceID) + _, err = k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + notFoundServicesIDs = append(notFoundServicesIDs, serviceID) + } return false, err } - - if !found { - notFoundServicesIDs = append(notFoundServicesIDs, serviceID) - } - return false, nil }) if err != nil { diff --git a/x/restaking/keeper/msg_server.go b/x/restaking/keeper/msg_server.go index 22cc73f5..1db0db8f 100644 --- a/x/restaking/keeper/msg_server.go +++ b/x/restaking/keeper/msg_server.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "cosmossdk.io/collections" "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/telemetry" sdk "github.com/cosmos/cosmos-sdk/types" @@ -30,28 +31,26 @@ func NewMsgServer(keeper *Keeper) types.MsgServer { // JoinService defines the rpc method for Msg/JoinService func (k msgServer) JoinService(ctx context.Context, msg *types.MsgJoinService) (*types.MsgJoinServiceResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can join the service") } - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) - } - if !service.IsActive() { return nil, errors.Wrapf(servicestypes.ErrServiceNotActive, "service %d is not active", msg.ServiceID) } @@ -75,28 +74,26 @@ func (k msgServer) JoinService(ctx context.Context, msg *types.MsgJoinService) ( // LeaveService defines the rpc method for Msg/LeaveService func (k msgServer) LeaveService(ctx context.Context, msg *types.MsgLeaveService) (*types.MsgLeaveServiceResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can leave the service") } - _, found, err = k.servicesKeeper.GetService(ctx, msg.ServiceID) + _, err = k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) - } - err = k.RemoveServiceFromOperatorJoinedServices(ctx, msg.OperatorID, msg.ServiceID) if err != nil { return nil, err @@ -117,25 +114,23 @@ func (k msgServer) LeaveService(ctx context.Context, msg *types.MsgLeaveService) // AddOperatorToAllowList defines the rpc method for Msg/AddOperatorToAllowList func (k msgServer) AddOperatorToAllowList(ctx context.Context, msg *types.MsgAddOperatorToAllowList) (*types.MsgAddOperatorToAllowListResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure that the operator exists - _, found, err = k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + _, err = k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the service admin can allow an operator") @@ -171,15 +166,14 @@ func (k msgServer) AddOperatorToAllowList(ctx context.Context, msg *types.MsgAdd // RemoveOperatorFromAllowlist defines the rpc method for Msg/RemoveOperatorFromAllowlist func (k msgServer) RemoveOperatorFromAllowlist(ctx context.Context, msg *types.MsgRemoveOperatorFromAllowlist) (*types.MsgRemoveOperatorFromAllowlistResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the service admin can allow an operator") @@ -215,30 +209,28 @@ func (k msgServer) RemoveOperatorFromAllowlist(ctx context.Context, msg *types.M // BorrowPoolSecurity defines the rpc method for Msg/BorrowPoolSecurity func (k msgServer) BorrowPoolSecurity(ctx context.Context, msg *types.MsgBorrowPoolSecurity) (*types.MsgBorrowPoolSecurityResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service is active if !service.IsActive() { return nil, errors.Wrapf(servicestypes.ErrServiceNotActive, "service %d is not active", msg.ServiceID) } // Ensure that the pool exists - _, found, err = k.poolsKeeper.GetPool(ctx, msg.PoolID) + _, err = k.poolsKeeper.GetPool(ctx, msg.PoolID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, poolstypes.ErrPoolNotFound + } return nil, err } - if !found { - return nil, poolstypes.ErrPoolNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, @@ -276,15 +268,14 @@ func (k msgServer) BorrowPoolSecurity(ctx context.Context, msg *types.MsgBorrowP // CeasePoolSecurityBorrow defines the rpc method for Msg/CeasePoolSecurityBorrow func (k msgServer) CeasePoolSecurityBorrow(ctx context.Context, msg *types.MsgCeasePoolSecurityBorrow) (*types.MsgCeasePoolSecurityBorrowResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, diff --git a/x/restaking/keeper/operator_restaking.go b/x/restaking/keeper/operator_restaking.go index a2ec3332..6ee967af 100644 --- a/x/restaking/keeper/operator_restaking.go +++ b/x/restaking/keeper/operator_restaking.go @@ -81,15 +81,14 @@ func (k *Keeper) RemoveOperatorDelegation(ctx context.Context, delegation types. // DelegateToOperator sends the given amount to the operator account and saves the delegation for the given user func (k *Keeper) DelegateToOperator(ctx context.Context, operatorID uint32, amount sdk.Coins, delegator string) (sdk.DecCoins, error) { // Get the operator - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return sdk.NewDecCoins(), operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return sdk.NewDecCoins(), operatorstypes.ErrOperatorNotFound - } - restakableDenoms, err := k.GetRestakableDenoms(ctx) if err != nil { return nil, err @@ -164,15 +163,14 @@ func (k *Keeper) GetOperatorUnbondingDelegation(ctx context.Context, operatorID // unbonding delegation for the given user func (k *Keeper) UndelegateFromOperator(ctx context.Context, operatorID uint32, amount sdk.Coins, delegator string) (time.Time, error) { // Find the operator - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return time.Time{}, operatorstypes.ErrOperatorNotFound + } return time.Time{}, err } - if !found { - return time.Time{}, operatorstypes.ErrOperatorNotFound - } - // Get the shares shares, err := k.ValidateUnbondAmount(ctx, delegator, operator, amount) if err != nil { diff --git a/x/restaking/keeper/operator_restaking_test.go b/x/restaking/keeper/operator_restaking_test.go index 3ff22c66..006e2fd8 100644 --- a/x/restaking/keeper/operator_restaking_test.go +++ b/x/restaking/keeper/operator_restaking_test.go @@ -367,9 +367,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, @@ -455,9 +454,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, @@ -562,9 +560,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, diff --git a/x/restaking/keeper/operators_hooks.go b/x/restaking/keeper/operators_hooks.go index 13ee6242..2da4abac 100644 --- a/x/restaking/keeper/operators_hooks.go +++ b/x/restaking/keeper/operators_hooks.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "errors" "fmt" "cosmossdk.io/collections" @@ -85,15 +86,14 @@ func (o *OperatorsHooks) removeOperatorFromServicesAllowList(ctx context.Context return err } if !isConfigured { - service, found, err := o.servicesKeeper.GetService(ctx, serviceID) + service, err := o.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return fmt.Errorf("service %d not found", serviceID) + } return err } - if !found { - return fmt.Errorf("service %d not found", serviceID) - } - if !service.IsActive() { // The service is not active, nothing to do continue diff --git a/x/restaking/keeper/operators_hooks_test.go b/x/restaking/keeper/operators_hooks_test.go index 52e288c1..dd98ef4d 100644 --- a/x/restaking/keeper/operators_hooks_test.go +++ b/x/restaking/keeper/operators_hooks_test.go @@ -82,9 +82,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure that the service is status has not changed - service, found, err := suite.sk.GetService(ctx, 2) + service, err := suite.sk.GetService(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -131,9 +130,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().True(joined) // Ensure that the service is status has not changed - service, found, err := suite.sk.GetService(ctx, 2) + service, err := suite.sk.GetService(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -164,9 +162,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure that the service is now inactive - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_INACTIVE, service.Status) }, operatorID: 1, @@ -200,9 +197,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -233,9 +229,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_CREATED, service.Status) }, operatorID: 1, @@ -266,9 +261,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_INACTIVE, service.Status) }, operatorID: 1, diff --git a/x/restaking/keeper/pool_restaking_test.go b/x/restaking/keeper/pool_restaking_test.go index 8bd7e8b7..ea98a5cf 100644 --- a/x/restaking/keeper/pool_restaking_test.go +++ b/x/restaking/keeper/pool_restaking_test.go @@ -308,9 +308,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDec(100))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", @@ -374,9 +373,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDec(500))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", @@ -448,9 +446,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDecWithPrec(15625, 2))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", diff --git a/x/restaking/keeper/service_restaking.go b/x/restaking/keeper/service_restaking.go index 3799fbc2..7fbe743a 100644 --- a/x/restaking/keeper/service_restaking.go +++ b/x/restaking/keeper/service_restaking.go @@ -214,15 +214,14 @@ func (k *Keeper) RemoveServiceDelegation(ctx context.Context, delegation types.D // DelegateToService sends the given amount to the service account and saves the delegation for the given user func (k *Keeper) DelegateToService(ctx context.Context, serviceID uint32, amount sdk.Coins, delegator string) (sdk.DecCoins, error) { // Get the service - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return sdk.NewDecCoins(), servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return sdk.NewDecCoins(), servicestypes.ErrServiceNotFound - } - restakableDenoms, err := k.GetRestakableDenoms(ctx) if err != nil { return nil, err @@ -317,15 +316,14 @@ func (k *Keeper) GetServiceUnbondingDelegation(ctx context.Context, serviceID ui // unbonding delegation for the given user func (k *Keeper) UndelegateFromService(ctx context.Context, serviceID uint32, amount sdk.Coins, delegator string) (time.Time, error) { // Find the service - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return time.Time{}, servicestypes.ErrServiceNotFound + } return time.Time{}, err } - if !found { - return time.Time{}, servicestypes.ErrServiceNotFound - } - // Get the shares shares, err := k.ValidateUnbondAmount(ctx, delegator, service, amount) if err != nil { diff --git a/x/restaking/keeper/service_restaking_test.go b/x/restaking/keeper/service_restaking_test.go index 4c72dce0..e5ef47f9 100644 --- a/x/restaking/keeper/service_restaking_test.go +++ b/x/restaking/keeper/service_restaking_test.go @@ -717,9 +717,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("service/1/umilk", sdkmath.LegacyNewDec(500))), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, @@ -804,9 +803,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("service/1/uinit", sdkmath.LegacyNewDec(100))), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, @@ -912,9 +910,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { ), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, diff --git a/x/restaking/types/expected_keepers.go b/x/restaking/types/expected_keepers.go index 62afe412..53faa944 100644 --- a/x/restaking/types/expected_keepers.go +++ b/x/restaking/types/expected_keepers.go @@ -24,14 +24,14 @@ type BankKeeper interface { type PoolsKeeper interface { GetPoolByDenom(ctx context.Context, denom string) (poolstypes.Pool, bool, error) CreateOrGetPoolByDenom(ctx context.Context, denom string) (poolstypes.Pool, error) - GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, bool, error) + GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, error) SavePool(ctx context.Context, pool poolstypes.Pool) error IteratePools(ctx context.Context, cb func(poolstypes.Pool) (bool, error)) error GetPools(ctx context.Context) ([]poolstypes.Pool, error) } type OperatorsKeeper interface { - GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, bool, error) + GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, error) SaveOperator(ctx context.Context, operator operatorstypes.Operator) error IterateOperators(ctx context.Context, cb func(operatorstypes.Operator) (bool, error)) error GetOperators(ctx context.Context) ([]operatorstypes.Operator, error) @@ -41,7 +41,7 @@ type OperatorsKeeper interface { type ServicesKeeper interface { HasService(ctx context.Context, serviceID uint32) (bool, error) - GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, bool, error) + GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, error) SaveService(ctx context.Context, service servicestypes.Service) error IterateServices(ctx context.Context, cb func(servicestypes.Service) (bool, error)) error GetServices(ctx context.Context) ([]servicestypes.Service, error) diff --git a/x/rewards/keeper/allocation.go b/x/rewards/keeper/allocation.go index 91251d74..a52f4bf3 100644 --- a/x/rewards/keeper/allocation.go +++ b/x/rewards/keeper/allocation.go @@ -197,15 +197,14 @@ func (k *Keeper) AllocateRewardsByPlan( return err } - service, found, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - // Ensure that we are distribution rewards only for active services if !service.IsActive() { return nil diff --git a/x/rewards/keeper/allocation_test.go b/x/rewards/keeper/allocation_test.go index 2b2e41e3..4796dd81 100644 --- a/x/rewards/keeper/allocation_test.go +++ b/x/rewards/keeper/allocation_test.go @@ -970,7 +970,7 @@ func (suite *KeeperTestSuite) TestAllocateRewards_InactiveOperator() { suite.Require().NoError(err) // Refresh the updated state of operator 2. - operator2, _, err = suite.operatorsKeeper.GetOperator(ctx, operator2.ID) + operator2, err = suite.operatorsKeeper.GetOperator(ctx, operator2.ID) suite.Require().NoError(err) // Operator 2 becomes inactive. err = suite.operatorsKeeper.StartOperatorInactivation(ctx, operator2) diff --git a/x/rewards/keeper/common_test.go b/x/rewards/keeper/common_test.go index 4666c575..8ae40b82 100644 --- a/x/rewards/keeper/common_test.go +++ b/x/rewards/keeper/common_test.go @@ -145,9 +145,8 @@ func (suite *KeeperTestSuite) CreateService(ctx sdk.Context, name string, admin _, err = servicesMsgServer.ActivateService(ctx, servicestypes.NewMsgActivateService(resp.NewServiceID, admin)) suite.Require().NoError(err) - service, found, err := suite.servicesKeeper.GetService(ctx, resp.NewServiceID) + service, err := suite.servicesKeeper.GetService(ctx, resp.NewServiceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") return service } @@ -167,9 +166,8 @@ func (suite *KeeperTestSuite) CreateOperator(ctx sdk.Context, name string, admin suite.Require().NoError(err) // Make sure the operator is found - operator, found, err := suite.operatorsKeeper.GetOperator(ctx, resp.NewOperatorID) + operator, err := suite.operatorsKeeper.GetOperator(ctx, resp.NewOperatorID) suite.Require().NoError(err) - suite.Require().True(found, "operator must be found") return operator } @@ -182,9 +180,8 @@ func (suite *KeeperTestSuite) UpdateOperatorParams( joinedServicesIDs []uint32, ) { // Make sure the operator is found - _, found, err := suite.operatorsKeeper.GetOperator(ctx, operatorID) + _, err := suite.operatorsKeeper.GetOperator(ctx, operatorID) suite.Require().NoError(err) - suite.Require().True(found, "operator must be found") // Sets the operator commission rate err = suite.operatorsKeeper.SaveOperatorParams(ctx, operatorID, operatorstypes.NewOperatorParams(commissionRate)) @@ -221,9 +218,8 @@ func (suite *KeeperTestSuite) AddPoolsToServiceSecuringPools( whitelistedPoolsIDs []uint32, ) { // Make sure the service is found - _, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + _, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") for _, poolID := range whitelistedPoolsIDs { err := suite.restakingKeeper.AddPoolToServiceSecuringPools(ctx, serviceID, poolID) @@ -239,9 +235,8 @@ func (suite *KeeperTestSuite) AddOperatorsToServiceAllowList( allowedOperatorsID []uint32, ) { // Make sure the service is found - _, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + _, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") for _, operatorID := range allowedOperatorsID { err := suite.restakingKeeper.AddOperatorToServiceAllowList(ctx, serviceID, operatorID) @@ -263,9 +258,8 @@ func (suite *KeeperTestSuite) CreateRewardsPlan( usersDistr rewardstypes.UsersDistribution, initialRewards sdk.Coins, ) rewardstypes.RewardsPlan { - service, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + service, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") rewardsMsgServer := keeper.NewMsgServer(suite.keeper) resp, err := rewardsMsgServer.CreateRewardsPlan(ctx, rewardstypes.NewMsgCreateRewardsPlan( diff --git a/x/rewards/keeper/hooks.go b/x/rewards/keeper/hooks.go index 67de03e9..0331926c 100644 --- a/x/rewards/keeper/hooks.go +++ b/x/rewards/keeper/hooks.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" restakingtypes "github.com/milkyway-labs/milkyway/v2/x/restaking/types" @@ -133,15 +135,14 @@ func (k *Keeper) AfterDelegationModified(ctx context.Context, delType restakingt // AfterServiceAccreditationModified implements servicestypes.ServicesHooks func (k *Keeper) AfterServiceAccreditationModified(ctx context.Context, serviceID uint32) error { - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - err = k.restakingKeeper.IterateServiceDelegations(ctx, serviceID, func(del restakingtypes.Delegation) (stop bool, err error) { preferences, err := k.restakingKeeper.GetUserPreferences(ctx, del.UserAddress) if err != nil { diff --git a/x/rewards/keeper/msg_server.go b/x/rewards/keeper/msg_server.go index a8c57635..32895bee 100644 --- a/x/rewards/keeper/msg_server.go +++ b/x/rewards/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -27,15 +28,14 @@ func NewMsgServer(k *Keeper) types.MsgServer { // CreateRewardsPlan defines the rpc method for Msg/CreateRewardsPlan func (k msgServer) CreateRewardsPlan(ctx context.Context, msg *types.MsgCreateRewardsPlan) (*types.MsgCreateRewardsPlanResponse, error) { // Make sure the creator is the admin of the service - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only service admin can create rewards plan") } @@ -103,15 +103,14 @@ func (k msgServer) EditRewardsPlan(ctx context.Context, msg *types.MsgEditReward } // Get the service to which the rewards is associated - service, found, err := k.servicesKeeper.GetService(ctx, rewardsPlan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, rewardsPlan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Make sure the editor is the admin of the service if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only service admin can create rewards plan") @@ -212,15 +211,14 @@ func (k msgServer) WithdrawOperatorCommission(ctx context.Context, msg *types.Ms return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid sender address: %s", err) } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if msg.Sender != operator.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only operator admin can withdraw operator commission") } diff --git a/x/rewards/keeper/rewards_plan.go b/x/rewards/keeper/rewards_plan.go index 010c5cee..550df4ed 100644 --- a/x/rewards/keeper/rewards_plan.go +++ b/x/rewards/keeper/rewards_plan.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -24,15 +25,14 @@ func (k *Keeper) CreateRewardsPlan( operatorsDistribution types.Distribution, usersDistribution types.UsersDistribution, ) (types.RewardsPlan, error) { - _, found, err := k.servicesKeeper.GetService(ctx, serviceID) + _, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.RewardsPlan{}, servicestypes.ErrServiceNotFound + } return types.RewardsPlan{}, err } - if !found { - return types.RewardsPlan{}, servicestypes.ErrServiceNotFound - } - // Get the plan id to be used planID, err := k.NextRewardsPlanID.Get(ctx) if err != nil { @@ -130,15 +130,14 @@ func (k *Keeper) terminateRewardsPlan(ctx context.Context, plan types.RewardsPla remaining := k.bankKeeper.GetAllBalances(ctx, rewardsPoolAddr) if remaining.IsAllPositive() { // Get the service's address. - service, found, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - serviceAddr, err := k.accountKeeper.AddressCodec().StringToBytes(service.Address) if err != nil { return err diff --git a/x/rewards/keeper/target.go b/x/rewards/keeper/target.go index b0ab5313..39a3b17b 100644 --- a/x/rewards/keeper/target.go +++ b/x/rewards/keeper/target.go @@ -35,13 +35,13 @@ func (k *Keeper) GetDelegationTarget( ) (DelegationTarget, error) { switch delType { case restakingtypes.DELEGATION_TYPE_POOL: - pool, found, err := k.poolsKeeper.GetPool(ctx, targetID) + pool, err := k.poolsKeeper.GetPool(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, poolstypes.ErrPoolNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, poolstypes.ErrPoolNotFound - } return DelegationTarget{ DelegationTarget: pool, DelegationType: delType, @@ -51,13 +51,13 @@ func (k *Keeper) GetDelegationTarget( OutstandingRewards: k.PoolOutstandingRewards, }, nil case restakingtypes.DELEGATION_TYPE_OPERATOR: - operator, found, err := k.operatorsKeeper.GetOperator(ctx, targetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, operatorstypes.ErrOperatorNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, operatorstypes.ErrOperatorNotFound - } return DelegationTarget{ DelegationTarget: operator, DelegationType: delType, @@ -67,13 +67,13 @@ func (k *Keeper) GetDelegationTarget( OutstandingRewards: k.OperatorOutstandingRewards, }, nil case restakingtypes.DELEGATION_TYPE_SERVICE: - service, found, err := k.servicesKeeper.GetService(ctx, targetID) + service, err := k.servicesKeeper.GetService(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, servicestypes.ErrServiceNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, servicestypes.ErrServiceNotFound - } return DelegationTarget{ DelegationTarget: service, DelegationType: delType, diff --git a/x/rewards/keeper/withdraw.go b/x/rewards/keeper/withdraw.go index 43b978a3..338762ae 100644 --- a/x/rewards/keeper/withdraw.go +++ b/x/rewards/keeper/withdraw.go @@ -2,8 +2,10 @@ package keeper import ( "context" + "errors" "fmt" + "cosmossdk.io/collections" errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -63,15 +65,14 @@ func (k *Keeper) WithdrawDelegationRewards( // WithdrawOperatorCommission withdraws the operator's accumulated commission func (k *Keeper) WithdrawOperatorCommission(ctx context.Context, operatorID uint32) (types.Pools, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - // Fetch the operator accumulated commission accumCommission, err := k.OperatorAccumulatedCommissions.Get(ctx, operatorID) if err != nil { diff --git a/x/rewards/types/expected_keepers.go b/x/rewards/types/expected_keepers.go index ceef2088..28b57156 100644 --- a/x/rewards/types/expected_keepers.go +++ b/x/rewards/types/expected_keepers.go @@ -43,20 +43,20 @@ type OracleKeeper interface { } type PoolsKeeper interface { - GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, bool, error) + GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, error) GetPools(ctx context.Context) ([]poolstypes.Pool, error) IteratePools(ctx context.Context, cb func(pool poolstypes.Pool) (stop bool, err error)) error } type OperatorsKeeper interface { - GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, bool, error) + GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, error) GetOperators(ctx context.Context) ([]operatorstypes.Operator, error) IterateOperators(ctx context.Context, cb func(operator operatorstypes.Operator) (stop bool, err error)) error GetOperatorParams(ctx context.Context, operatorID uint32) (operatorstypes.OperatorParams, error) } type ServicesKeeper interface { - GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, bool, error) + GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, error) GetServiceParams(ctx context.Context, serviceID uint32) (servicestypes.ServiceParams, error) IterateServices(ctx context.Context, cb func(service servicestypes.Service) (stop bool, err error)) error } diff --git a/x/services/keeper/grpc_query.go b/x/services/keeper/grpc_query.go index 44fdce1c..5e37f0b7 100644 --- a/x/services/keeper/grpc_query.go +++ b/x/services/keeper/grpc_query.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -18,15 +20,14 @@ func (k *Keeper) Service(ctx context.Context, request *types.QueryServiceRequest return nil, status.Error(codes.InvalidArgument, "invalid service ID") } - service, found, err := k.GetService(ctx, request.ServiceId) + service, err := k.GetService(ctx, request.ServiceId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - return &types.QueryServiceResponse{Service: service}, nil } @@ -54,15 +55,14 @@ func (k *Keeper) ServiceParams(ctx context.Context, request *types.QueryServiceP } // Ensure the service exists - _, found, err := k.GetService(ctx, request.ServiceId) + _, err := k.GetService(ctx, request.ServiceId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - // Get the service params serviceParams, err := k.GetServiceParams(ctx, request.ServiceId) if err != nil { diff --git a/x/services/keeper/msg_server.go b/x/services/keeper/msg_server.go index 67002ee2..698c8acb 100644 --- a/x/services/keeper/msg_server.go +++ b/x/services/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -104,15 +105,14 @@ func (k msgServer) CreateService(goCtx context.Context, msg *types.MsgCreateServ // UpdateService defines the rpc method for Msg/UpdateService func (k msgServer) UpdateService(ctx context.Context, msg *types.MsgUpdateService) (*types.MsgUpdateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrInvalidRequest, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is updating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the service") @@ -146,15 +146,14 @@ func (k msgServer) UpdateService(ctx context.Context, msg *types.MsgUpdateServic func (k msgServer) ActivateService(ctx context.Context, msg *types.MsgActivateService) (*types.MsgActivateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is activating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can activate the service") @@ -181,15 +180,14 @@ func (k msgServer) ActivateService(ctx context.Context, msg *types.MsgActivateSe // DeactivateService defines the rpc method for Msg/DeactivateService func (k msgServer) DeactivateService(ctx context.Context, msg *types.MsgDeactivateService) (*types.MsgDeactivateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is deactivating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the service") @@ -215,15 +213,14 @@ func (k msgServer) DeactivateService(ctx context.Context, msg *types.MsgDeactiva func (k msgServer) DeleteService(ctx context.Context, msg *types.MsgDeleteService) (*types.MsgDeleteServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is deleting the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can delete the service") @@ -250,15 +247,14 @@ func (k msgServer) DeleteService(ctx context.Context, msg *types.MsgDeleteServic // TransferServiceOwnership defines the rpc method for Msg/TransferServiceOwnership func (k msgServer) TransferServiceOwnership(ctx context.Context, msg *types.MsgTransferServiceOwnership) (*types.MsgTransferServiceOwnershipResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, types.ErrServiceNotFound - } - // Make sure only the admin can transfer the service ownership if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can transfer the service ownership") @@ -286,15 +282,14 @@ func (k msgServer) TransferServiceOwnership(ctx context.Context, msg *types.MsgT // SetServiceParams define the rpc method for Msg/SetServiceParams func (k msgServer) SetServiceParams(ctx context.Context, msg *types.MsgSetServiceParams) (*types.MsgSetServiceParamsResponse, error) { // Get the service whose params are being set - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, types.ErrServiceNotFound - } - // Ensure the sender is the service admin if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "sender must be the service admin") diff --git a/x/services/keeper/msg_server_test.go b/x/services/keeper/msg_server_test.go index 6eb9ccdf..f42293d5 100644 --- a/x/services/keeper/msg_server_test.go +++ b/x/services/keeper/msg_server_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -118,9 +119,8 @@ func (suite *KeeperTestSuite) TestMsgServer_CreateService() { }, check: func(ctx sdk.Context) { // Make sure the service has been stored - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -200,9 +200,8 @@ func (suite *KeeperTestSuite) TestMsgServer_CreateService() { }, check: func(ctx sdk.Context) { // Make sure the service has been stored - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -386,9 +385,8 @@ func (suite *KeeperTestSuite) TestMsgServer_UpdateService() { }, check: func(ctx sdk.Context) { // Make sure the service was updated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -632,9 +630,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeactivateService() { }, check: func(ctx sdk.Context) { // Make sure the service was deactivated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_INACTIVE, @@ -770,9 +767,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteService() { }, check: func(ctx sdk.Context) { // Make sure the service was removed - _, found, err := suite.k.GetService(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetService(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) }, }, { @@ -803,9 +799,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteService() { }, check: func(ctx sdk.Context) { // Make sure the service was removed - _, found, err := suite.k.GetService(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetService(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) }, }, } @@ -914,9 +909,8 @@ func (suite *KeeperTestSuite) TestMsgServer_TransferServiceOwnership() { }, check: func(ctx sdk.Context) { // Make sure the service was updated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -1228,9 +1222,8 @@ func (suite *KeeperTestSuite) TestMsgServer_AccreditService() { ), }, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().True(service.Accredited) }, }, @@ -1325,9 +1318,8 @@ func (suite *KeeperTestSuite) TestMsgService_RevokeServiceAccreditation() { ), }, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().False(service.Accredited) }, }, diff --git a/x/services/keeper/services.go b/x/services/keeper/services.go index 8d5c31b4..18f6641e 100644 --- a/x/services/keeper/services.go +++ b/x/services/keeper/services.go @@ -64,15 +64,14 @@ func (k *Keeper) CreateService(ctx context.Context, service types.Service) error // ActivateService activates the service with the given ID func (k *Keeper) ActivateService(ctx context.Context, serviceID uint32) error { - service, found, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !found { - return types.ErrServiceNotFound - } - // Check if the service is already active if service.Status == types.SERVICE_STATUS_ACTIVE { return types.ErrServiceAlreadyActive @@ -93,15 +92,14 @@ func (k *Keeper) ActivateService(ctx context.Context, serviceID uint32) error { // DeactivateService deactivates the service with the given ID func (k *Keeper) DeactivateService(ctx context.Context, serviceID uint32) error { - service, exists, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !exists { - return types.ErrServiceNotFound - } - // Make sure the service is active if service.Status != types.SERVICE_STATUS_ACTIVE { return types.ErrServiceNotActive @@ -121,15 +119,14 @@ func (k *Keeper) DeactivateService(ctx context.Context, serviceID uint32) error // DeleteService deletes the service with the given ID func (k *Keeper) DeleteService(ctx context.Context, serviceID uint32) error { - service, exists, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !exists { - return types.ErrServiceNotFound - } - // Make sure the service is not active if service.Status == types.SERVICE_STATUS_ACTIVE { return types.ErrServiceIsActive @@ -153,15 +150,14 @@ func (k *Keeper) DeleteService(ctx context.Context, serviceID uint32) error { // SetServiceAccredited sets the accreditation of the service with the given ID func (k *Keeper) SetServiceAccredited(ctx context.Context, serviceID uint32, accredited bool) error { // Check if the service exists - service, found, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !found { - return errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", serviceID) - } - // Skip any operation if the service accreditation status does not change if service.Accredited == accredited { return nil @@ -184,15 +180,8 @@ func (k *Keeper) HasService(ctx context.Context, serviceID uint32) (bool, error) } // GetService returns an Service from the KVStore -func (k *Keeper) GetService(ctx context.Context, serviceID uint32) (service types.Service, found bool, err error) { - service, err = k.services.Get(ctx, serviceID) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return service, false, nil - } - return service, false, err - } - return service, true, nil +func (k *Keeper) GetService(ctx context.Context, serviceID uint32) (service types.Service, err error) { + return k.services.Get(ctx, serviceID) } // GetServiceParams returns the params for the service with the given ID diff --git a/x/services/keeper/services_test.go b/x/services/keeper/services_test.go index d2e78893..20797276 100644 --- a/x/services/keeper/services_test.go +++ b/x/services/keeper/services_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -145,9 +146,8 @@ func (suite *KeeperTestSuite) TestKeeper_CreateService() { suite.Require().True(hasAccount) // Make sure the service has been created - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -240,9 +240,8 @@ func (suite *KeeperTestSuite) TestKeeper_ActivateService() { serviceID: 1, shouldErr: false, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -335,9 +334,8 @@ func (suite *KeeperTestSuite) TestKeeper_DeactivateService() { serviceID: 1, shouldErr: false, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_INACTIVE, @@ -415,9 +413,8 @@ func (suite *KeeperTestSuite) TestKeeper_SetServiceAccreditation() { shouldErr: false, check: func(ctx sdk.Context) { // Accreditation didn't change - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().False(service.Accredited) // Make sure the hook wasn't called @@ -444,9 +441,8 @@ func (suite *KeeperTestSuite) TestKeeper_SetServiceAccreditation() { shouldErr: false, check: func(ctx sdk.Context) { // Accreditation changed - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().True(service.Accredited) // Make sure the hook was called @@ -485,14 +481,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { name string store func(ctx sdk.Context) serviceID uint32 - shouldErr bool expFound bool expService types.Service }{ { name: "service not found returns false", serviceID: 1, - shouldErr: false, expFound: false, }, { @@ -511,7 +505,6 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { suite.Require().NoError(err) }, serviceID: 1, - shouldErr: false, expFound: true, expService: types.NewService( 1, @@ -534,18 +527,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { tc.store(ctx) } - service, found, err := suite.k.GetService(ctx, tc.serviceID) - if tc.shouldErr { - suite.Require().Error(err) + service, err := suite.k.GetService(ctx, tc.serviceID) + if !tc.expFound { + suite.Require().ErrorIs(err, collections.ErrNotFound) } else { suite.Require().NoError(err) - - if !tc.expFound { - suite.Require().False(found) - } else { - suite.Require().True(found) - suite.Require().Equal(tc.expService, service) - } + suite.Require().Equal(tc.expService, service) } }) }