From 46120671e0ab8c4aa08a2f09e59a154174bdb90d Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 29 Jun 2024 17:51:50 +0300 Subject: [PATCH] refactor(x/participationrewards/keeper): combine GetProtocolData+UnmarshalProtocolData This change combines GetProtocolData and types.UnmarshalProtocolData into a generic function that unifies the functionality and the pattern. Fixes #1631 --- x/participationrewards/keeper/callbacks.go | 84 ++++--------------- .../keeper/callbacks_test.go | 52 ++---------- x/participationrewards/keeper/distribution.go | 8 +- .../keeper/protocol_data.go | 18 ++++ 4 files changed, 42 insertions(+), 120 deletions(-) diff --git a/x/participationrewards/keeper/callbacks.go b/x/participationrewards/keeper/callbacks.go index 892bab1cd..528267347 100644 --- a/x/participationrewards/keeper/callbacks.go +++ b/x/participationrewards/keeper/callbacks.go @@ -139,18 +139,12 @@ func OsmosisPoolUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, quer } poolID := sdk.BigEndianToUint64(query.Request[1:]) - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, fmt.Sprintf("%d", poolID)) - if !ok { - return fmt.Errorf("unable to find protocol data for osmosispools/%d", poolID) - } - ipool, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, data.Data) + key := fmt.Sprintf("%d", poolID) + data, pool, err := GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, k, key, types.ProtocolDataTypeOsmosisPool) if err != nil { return err } - pool, ok := ipool.(*types.OsmosisPoolProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for osmosispools/%d", poolID) - } + pool.PoolData, err = json.Marshal(pd) if err != nil { return err @@ -185,18 +179,11 @@ func OsmosisClPoolUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, qu return err } - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisCLPool, fmt.Sprintf("%d", poolID)) - if !ok { - return fmt.Errorf("unable to find protocol data for osmosisclpools/%d", poolID) - } - ipool, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisCLPool, data.Data) + data, pool, err := GetAndUnmarshalProtocolData[*types.OsmosisClPoolProtocolData](ctx, k, fmt.Sprintf("%d", poolID), types.ProtocolDataTypeOsmosisCLPool) if err != nil { return err } - pool, ok := ipool.(*types.OsmosisClPoolProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for osmosisclpools/%d", poolID) - } + pool.PoolData, err = json.Marshal(pd) if err != nil { return err @@ -222,18 +209,11 @@ func UmeeReservesUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, que } denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixReserveAmount) - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeReserves, denom) - if !ok { - return fmt.Errorf("unable to find protocol data for umeereserves/%s", denom) - } - ireserves, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeReserves, data.Data) + data, reserves, err := GetAndUnmarshalProtocolData[*types.UmeeReservesProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeReserves) if err != nil { return err } - reserves, ok := ireserves.(*types.UmeeReservesProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for umeereserves/%s", denom) - } + reserves.Data, err = json.Marshal(reserveAmount) if err != nil { return err @@ -259,18 +239,11 @@ func UmeeTotalBorrowsUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, } denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixAdjustedTotalBorrow) - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeTotalBorrows, denom) - if !ok { - return fmt.Errorf("unable to find protocol data for umee-types total borrows/%s", denom) - } - iborrows, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeTotalBorrows, data.Data) + data, borrows, err := GetAndUnmarshalProtocolData[*types.UmeeTotalBorrowsProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeTotalBorrows) if err != nil { return err } - borrows, ok := iborrows.(*types.UmeeTotalBorrowsProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for umee-types total borrows/%s", denom) - } + borrows.Data, err = json.Marshal(totalBorrows) if err != nil { return err @@ -296,18 +269,11 @@ func UmeeInterestScalarUpdateCallback(ctx sdk.Context, k *Keeper, response []byt } denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixInterestScalar) - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeInterestScalar, denom) - if !ok { - return fmt.Errorf("unable to find protocol data for interestscalar/%s", denom) - } - iinterest, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeInterestScalar, data.Data) + data, interest, err := GetAndUnmarshalProtocolData[*types.UmeeInterestScalarProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeInterestScalar) if err != nil { return err } - interest, ok := iinterest.(*types.UmeeInterestScalarProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for interestscalar/%s", denom) - } + interest.Data, err = json.Marshal(interestScalar) if err != nil { return err @@ -333,18 +299,10 @@ func UmeeUTokenSupplyUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, } denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixUtokenSupply) - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeUTokenSupply, denom) - if !ok { - return fmt.Errorf("unable to find protocol data for umee-types utoken supply/%s", denom) - } - isupply, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeUTokenSupply, data.Data) + data, supply, err := GetAndUnmarshalProtocolData[*types.UmeeUTokenSupplyProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeUTokenSupply) if err != nil { return err } - supply, ok := isupply.(*types.UmeeUTokenSupplyProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for umee-types utoken supply/%s", denom) - } supply.Data, err = json.Marshal(supplyAmount) if err != nil { return err @@ -377,18 +335,10 @@ func UmeeLeverageModuleBalanceUpdateCallback(ctx sdk.Context, k *Keeper, respons } balanceAmount := balanceCoin.Amount - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeLeverageModuleBalance, denom) - if !ok { - return fmt.Errorf("unable to find protocol data for umee-types leverage module/%s", denom) - } - ibalance, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeLeverageModuleBalance, data.Data) + data, balance, err := GetAndUnmarshalProtocolData[*types.UmeeLeverageModuleBalanceProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeLeverageModuleBalance) if err != nil { return err } - balance, ok := ibalance.(*types.UmeeLeverageModuleBalanceProtocolData) - if !ok { - return fmt.Errorf("unable to unmarshal protocol data for umee-types leverage module/%s", denom) - } balance.Data, err = json.Marshal(balanceAmount) if err != nil { return err @@ -405,14 +355,8 @@ func UmeeLeverageModuleBalanceUpdateCallback(ctx sdk.Context, k *Keeper, respons // SetEpochBlockCallback records the block height of the registered zone at the epoch boundary. func SetEpochBlockCallback(ctx sdk.Context, k *Keeper, args []byte, query icqtypes.Query) error { - data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeConnection, query.ChainId) - if !ok { - return fmt.Errorf("unable to find protocol data for connection/%s", query.ChainId) - } k.Logger(ctx).Debug("epoch callback called") - iConnectionData, err := types.UnmarshalProtocolData(types.ProtocolDataTypeConnection, data.Data) - connectionData, _ := iConnectionData.(*types.ConnectionProtocolData) - + data, connectionData, err := GetAndUnmarshalProtocolData[*types.ConnectionProtocolData](ctx, k, query.ChainId, types.ProtocolDataTypeConnection) if err != nil { return err } diff --git a/x/participationrewards/keeper/callbacks_test.go b/x/participationrewards/keeper/callbacks_test.go index 2ed038c82..1cb4b21f5 100644 --- a/x/participationrewards/keeper/callbacks_test.go +++ b/x/participationrewards/keeper/callbacks_test.go @@ -73,15 +73,9 @@ func (suite *KeeperTestSuite) TestOsmosisPoolUpdateCallback() { suite.NoError(err) - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, "944") - suite.True(found) - - data, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, pd.GetData()) + _, pooldata, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, prk, "944", types.ProtocolDataTypeOsmosisPool) suite.NoError(err) - pooldata, ok := data.(*types.OsmosisPoolProtocolData) - suite.True(ok) - pool, err := pooldata.GetPool() suite.NoError(err) @@ -137,15 +131,9 @@ func (suite *KeeperTestSuite) TestOsmosisClPoolUpdateCallback() { suite.NoError(err) - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisCLPool, "1089") - suite.True(found) - - data, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisCLPool, pd.GetData()) + _, pooldata, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisClPoolProtocolData](ctx, prk, "1089", types.ProtocolDataTypeOsmosisCLPool) suite.NoError(err) - pooldata, ok := data.(*types.OsmosisClPoolProtocolData) - suite.True(ok) - pool, err := pooldata.GetPool() suite.NoError(err) @@ -198,12 +186,8 @@ func (suite *KeeperTestSuite) executeOsmosisPoolUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, "1") - suite.True(found) - - ioppd, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, pd.Data) + _, oppd, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, prk, "1", types.ProtocolDataTypeOsmosisPool) suite.NoError(err) - oppd := ioppd.(*types.OsmosisPoolProtocolData) suite.Equal(want, oppd) } @@ -321,12 +305,8 @@ func (suite *KeeperTestSuite) executeUmeeReservesUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeReserves, umeeBaseDenom) - suite.True(found) - - value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeReserves, pd.Data) + _, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeReservesProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeReserves) suite.NoError(err) - result := value.(*types.UmeeReservesProtocolData) suite.Equal(want, result) } @@ -365,12 +345,8 @@ func (suite *KeeperTestSuite) executeUmeeLeverageModuleBalanceUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeLeverageModuleBalance, umeeBaseDenom) - suite.True(found) - - value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeLeverageModuleBalance, pd.Data) + _, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeLeverageModuleBalanceProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeLeverageModuleBalance) suite.NoError(err) - result := value.(*types.UmeeLeverageModuleBalanceProtocolData) suite.Equal(want, result) } @@ -407,12 +383,8 @@ func (suite *KeeperTestSuite) executeUmeeUTokenSupplyUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeUTokenSupply, leveragetypes.UTokenPrefix+umeeBaseDenom) - suite.True(found) - - value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeUTokenSupply, pd.Data) + _, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeUTokenSupplyProtocolData](ctx, prk, leveragetypes.UTokenPrefix+umeeBaseDenom, types.ProtocolDataTypeUmeeUTokenSupply) suite.NoError(err) - result := value.(*types.UmeeUTokenSupplyProtocolData) suite.Equal(want, result) } @@ -449,12 +421,8 @@ func (suite *KeeperTestSuite) executeUmeeTotalBorrowsUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeTotalBorrows, umeeBaseDenom) - suite.True(found) - - value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeTotalBorrows, pd.Data) + _, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeTotalBorrowsProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeTotalBorrows) suite.NoError(err) - result := value.(*types.UmeeTotalBorrowsProtocolData) suite.Equal(want, result) } @@ -491,11 +459,7 @@ func (suite *KeeperTestSuite) executeUmeeInterestScalarUpdateCallback() { }, } - pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeInterestScalar, umeeBaseDenom) - suite.True(found) - - value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeInterestScalar, pd.Data) + _, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeInterestScalarProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeInterestScalar) suite.NoError(err) - result := value.(*types.UmeeInterestScalarProtocolData) suite.Equal(want, result) } diff --git a/x/participationrewards/keeper/distribution.go b/x/participationrewards/keeper/distribution.go index dbacf0fd1..d8c2a40fd 100644 --- a/x/participationrewards/keeper/distribution.go +++ b/x/participationrewards/keeper/distribution.go @@ -37,16 +37,12 @@ func DepthFirstSearch(graph AssetGraph, visited map[string]struct{}, asset strin func (k *Keeper) CalcTokenValues(ctx sdk.Context) (TokenValues, error) { k.Logger(ctx).Info("calcTokenValues") - data, found := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisParams, "osmosisparams") - if !found { - return TokenValues{}, errors.New("could not find osmosisparams protocol data") - } - osmoParams, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisParams, data.Data) + _, osmoParams, err := GetAndUnmarshalProtocolData[*types.OsmosisParamsProtocolData](ctx, k, "osmosisparams", types.ProtocolDataTypeOsmosisParams) if err != nil { return TokenValues{}, err } - baseDenom := osmoParams.(*types.OsmosisParamsProtocolData).BaseDenom + baseDenom := osmoParams.BaseDenom tvs := make(TokenValues) graph := make(AssetGraphSlice) diff --git a/x/participationrewards/keeper/protocol_data.go b/x/participationrewards/keeper/protocol_data.go index 3e5d8ef79..787f9dc18 100644 --- a/x/participationrewards/keeper/protocol_data.go +++ b/x/participationrewards/keeper/protocol_data.go @@ -1,6 +1,8 @@ package keeper import ( + "fmt" + "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -38,6 +40,22 @@ func (k Keeper) SetProtocolData(ctx sdk.Context, key []byte, data *types.Protoco store.Set(types.GetProtocolDataKey(types.ProtocolDataType(pdType), key), bz) } +func GetAndUnmarshalProtocolData[T any](ctx sdk.Context, k *Keeper, key string, pdType types.ProtocolDataType) (dt types.ProtocolData, tt T, err error) { + data, ok := k.GetProtocolData(ctx, pdType, key) + if !ok { + return dt, tt, fmt.Errorf("unable to find protocol data for %q", key) + } + pd, err := types.UnmarshalProtocolData(pdType, data.Data) + if err != nil { + return dt, tt, err + } + asType, ok := pd.(T) + if !ok { + return dt, tt, fmt.Errorf("could not retrieve type of %T, actual type: %T", (*T)(nil), pd) + } + return data, asType, nil +} + // DeleteProtocolData deletes protocol data info. func (k *Keeper) DeleteProtocolData(ctx sdk.Context, key []byte) { store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefixProtocolData)