Skip to content

Commit

Permalink
fix: ABCI Removal Fix of Currency Pairs (#300)
Browse files Browse the repository at this point in the history
* init

* keeper fixes

* testing
  • Loading branch information
davidterpay authored Apr 11, 2024
1 parent 043ce53 commit d13181b
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 61 deletions.
20 changes: 18 additions & 2 deletions abci/strategies/currencypair/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,25 @@ func (s *DefaultCurrencyPairStrategy) GetDecodedPrice(
}

// GetMaxNumCP returns the number of pairs that the VEs should include. This method returns an error if the size cannot
// be queried from the x/oracle state.
// be queried from the x/oracle state. Specifically, this method should return the maximum number of currency pairs that
// could have existed at the time at which the votes were created. As such, if the execution mode is PrepareProposal or
// ProcessProposal, the number of removed currency pairs in the previous block should be included in the total.
func (s *DefaultCurrencyPairStrategy) GetMaxNumCP(
ctx sdk.Context,
) (uint64, error) {
return s.oracleKeeper.GetPrevBlockCPCounter(ctx)
current, err := s.oracleKeeper.GetNumCurrencyPairs(ctx)
if err != nil {
return 0, err
}

if mode := ctx.ExecMode(); mode == sdk.ExecModePrepareProposal || mode == sdk.ExecModeProcessProposal {
removed, err := s.oracleKeeper.GetNumRemovedCurrencyPairs(ctx)
if err != nil {
return 0, err
}

return current + removed, nil
}

return current, nil
}
83 changes: 83 additions & 0 deletions abci/strategies/currencypair/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,86 @@ func TestDefaultCurrencyPairStrategyGetEncodedPrice(t *testing.T) {
require.Error(t, err)
})
}

func TestGetMaxNumCP(t *testing.T) {
ok := mocks.NewOracleKeeper(t)
strategy := strategies.NewDefaultCurrencyPairStrategy(ok)

t.Run("can get max number of currency pairs with no removals, PrepareProposal", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModePrepareProposal)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numRemovedInPrevBlock := uint64(0)
ok.On("GetNumRemovedCurrencyPairs", ctx).Return(numRemovedInPrevBlock, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP, numCP)
})

t.Run("can get max number of currency pairs with removals, PrepareProposal", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModePrepareProposal)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numRemovedInPrevBlock := uint64(10)
ok.On("GetNumRemovedCurrencyPairs", ctx).Return(numRemovedInPrevBlock, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP+numRemovedInPrevBlock, numCP)
})

t.Run("can get max number of currency pairs with no removals, ProcessProposal", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModeProcessProposal)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numRemovedInPrevBlock := uint64(0)
ok.On("GetNumRemovedCurrencyPairs", ctx).Return(numRemovedInPrevBlock, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP, numCP)
})

t.Run("can get max number of currency pairs with removals, ProcessProposal", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModeProcessProposal)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numRemovedInPrevBlock := uint64(10)
ok.On("GetNumRemovedCurrencyPairs", ctx).Return(numRemovedInPrevBlock, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP+numRemovedInPrevBlock, numCP)
})

t.Run("can get max number of currency pairs for extend vote", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModeVoteExtension)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP, numCP)
})

t.Run("can get max number of currency pairs for verify vote", func(t *testing.T) {
ctx := sdk.Context{}.WithExecMode(sdk.ExecModeVerifyVoteExtension)

maxNumCP := uint64(100)
ok.On("GetNumCurrencyPairs", ctx).Return(maxNumCP, nil).Once()

numCP, err := strategy.GetMaxNumCP(ctx)
require.NoError(t, err)
require.Equal(t, maxNumCP, numCP)
})
}
8 changes: 0 additions & 8 deletions abci/strategies/currencypair/delta.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,3 @@ func (s *DeltaCurrencyPairStrategy) getOnChainPrice(ctx sdk.Context, cp slinkyty
s.cache[cp] = currentPrice
return currentPrice, nil
}

// GetMaxNumCP returns the number of pairs that the VEs should include. This method returns an error if the size cannot
// be queried from the x/oracle state.
func (s *DeltaCurrencyPairStrategy) GetMaxNumCP(
ctx sdk.Context,
) (uint64, error) {
return s.oracleKeeper.GetPrevBlockCPCounter(ctx)
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 33 additions & 6 deletions abci/strategies/currencypair/mocks/mock_oracle_keeper.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion abci/strategies/currencypair/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ type OracleKeeper interface {
GetCurrencyPairFromID(ctx sdk.Context, id uint64) (cp slinkytypes.CurrencyPair, found bool)
GetIDForCurrencyPair(ctx sdk.Context, cp slinkytypes.CurrencyPair) (uint64, bool)
GetPriceForCurrencyPair(ctx sdk.Context, cp slinkytypes.CurrencyPair) (oracletypes.QuotePrice, error)
GetPrevBlockCPCounter(ctx sdk.Context) (uint64, error)
GetNumCurrencyPairs(ctx sdk.Context) (uint64, error)
GetNumRemovedCurrencyPairs(ctx sdk.Context) (uint64, error)
}

// CurrencyPairStrategy is a strategy for generating a unique ID and price representation for a given currency pair.
Expand Down
21 changes: 0 additions & 21 deletions x/oracle/keeper/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package keeper

import (
"context"
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
)
Expand All @@ -12,25 +11,5 @@ import (
func (k *Keeper) BeginBlocker(goCtx context.Context) error {
// unwrap the context
ctx := sdk.UnwrapSDKContext(goCtx)

removes, err := k.numRemoves.Get(ctx)
if err != nil {
return err
}

numCPs, err := k.numCPs.Get(ctx)
if err != nil {
return err
}

if numCPs < removes {
return fmt.Errorf("invalid decrement amount - result will be negative")
}

err = k.numCPs.Set(ctx, numCPs-removes)
if err != nil {
return err
}

return k.numRemoves.Set(ctx, 0)
}
28 changes: 20 additions & 8 deletions x/oracle/keeper/abci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,35 @@ import (
func (s *KeeperTestSuite) TestBeginBlocker() {
s.Run("run with no state", func() {
s.Require().NoError(s.oracleKeeper.BeginBlocker(s.ctx))
removes, err := s.oracleKeeper.GetRemovedCPCounter(s.ctx)
removes, err := s.oracleKeeper.GetNumRemovedCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(removes, uint64(0))
})

s.Run("run with 1 in state - 1 removed", func() {
// Create the currency pair.
s.Require().NoError(s.oracleKeeper.CreateCurrencyPair(s.ctx, slinkytypes.CurrencyPair{Base: "test", Quote: "coin1"}))
cps, err := s.oracleKeeper.GetNumCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(cps, uint64(1))
removed, err := s.oracleKeeper.GetNumRemovedCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(removed, uint64(0))

// Remove the currency pair.
s.Require().NoError(s.oracleKeeper.RemoveCurrencyPair(s.ctx, slinkytypes.CurrencyPair{Base: "test", Quote: "coin1"}))
cps, err = s.oracleKeeper.GetNumCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(cps, uint64(0))
removed, err = s.oracleKeeper.GetNumRemovedCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(removed, uint64(1))

// Begin blocker should reset the removed count.
s.Require().NoError(s.oracleKeeper.BeginBlocker(s.ctx))
removes, err := s.oracleKeeper.GetRemovedCPCounter(s.ctx)
removes, err := s.oracleKeeper.GetNumRemovedCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(removes, uint64(0))

cps, err := s.oracleKeeper.GetPrevBlockCPCounter(s.ctx)
s.Require().NoError(err)
s.Require().Equal(cps, uint64(0))
})

s.Run("run with 2 in state - 1 removed", func() {
Expand All @@ -32,11 +44,11 @@ func (s *KeeperTestSuite) TestBeginBlocker() {
s.Require().NoError(s.oracleKeeper.CreateCurrencyPair(s.ctx, slinkytypes.CurrencyPair{Base: "test", Quote: "coin2"}))

s.Require().NoError(s.oracleKeeper.BeginBlocker(s.ctx))
removes, err := s.oracleKeeper.GetRemovedCPCounter(s.ctx)
removes, err := s.oracleKeeper.GetNumRemovedCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(removes, uint64(0))

cps, err := s.oracleKeeper.GetPrevBlockCPCounter(s.ctx)
cps, err := s.oracleKeeper.GetNumCurrencyPairs(s.ctx)
s.Require().NoError(err)
s.Require().Equal(cps, uint64(1))
})
Expand Down
30 changes: 25 additions & 5 deletions x/oracle/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,19 @@ func NewKeeper(

// RemoveCurrencyPair removes a given CurrencyPair from state, i.e. removes its nonce + QuotePrice from the module's store.
func (k *Keeper) RemoveCurrencyPair(ctx sdk.Context, cp slinkytypes.CurrencyPair) error {
// check if the currency pair exists.
if !k.HasCurrencyPair(ctx, cp) {
return types.NewCurrencyPairNotExistError(cp)
}

if err := k.currencyPairs.Remove(ctx, cp.String()); err != nil {
return err
}
return k.incrementRemovedCPCounter(ctx)
if err := k.incrementRemovedCPCounter(ctx); err != nil {
return err
}

return k.decrementCPCounter(ctx)
}

// HasCurrencyPair returns true if a given CurrencyPair is stored in state, false otherwise.
Expand Down Expand Up @@ -334,8 +343,8 @@ func (k *Keeper) incrementRemovedCPCounter(ctx sdk.Context) error {
return k.numRemoves.Set(ctx, val)
}

// GetRemovedCPCounter gets the counter of removed currency pairs.
func (k *Keeper) GetRemovedCPCounter(ctx sdk.Context) (uint64, error) {
// GetNumRemovedCurrencyPairs gets the counter of removed currency pairs in the previous block.
func (k *Keeper) GetNumRemovedCurrencyPairs(ctx sdk.Context) (uint64, error) {
return k.numRemoves.Get(ctx)
}

Expand All @@ -350,7 +359,18 @@ func (k *Keeper) incrementCPCounter(ctx sdk.Context) error {
return k.numCPs.Set(ctx, val)
}

// GetPrevBlockCPCounter gets the counter of currency pairs in the previous block.
func (k *Keeper) GetPrevBlockCPCounter(ctx sdk.Context) (uint64, error) {
// DecrementCPCounter decrements the counter of currency pairs.
func (k *Keeper) decrementCPCounter(ctx sdk.Context) error {
val, err := k.numCPs.Get(ctx)
if err != nil {
return err
}

val--
return k.numCPs.Set(ctx, val)
}

// GetNumCurrencyPairs returns the number of currency pairs currently in state.
func (k *Keeper) GetNumCurrencyPairs(ctx sdk.Context) (uint64, error) {
return k.numCPs.Get(ctx)
}
Loading

0 comments on commit d13181b

Please sign in to comment.