Skip to content

Commit

Permalink
Merge branch 'feat/partial-set-security' into sainoe/pss-query
Browse files Browse the repository at this point in the history
  • Loading branch information
sainoe committed Mar 26, 2024
2 parents b311f35 + 8ed60ee commit 22ca561
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gosec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
- name: Run Gosec Security Scanner
uses: securego/gosec@master
with:
args: -exclude-dir=legacy_ibc_testing ./... -exclude-generated ./...
args: -exclude-dir=tests ./... -exclude-generated ./...
46 changes: 46 additions & 0 deletions x/ccv/provider/client/cli/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func GetTxCmd() *cobra.Command {
cmd.AddCommand(NewSubmitConsumerDoubleVotingCmd())
cmd.AddCommand(NewOptInCmd())
cmd.AddCommand(NewOptOutCmd())
cmd.AddCommand(NewSetConsumerCommissionRateCmd())

return cmd
}
Expand Down Expand Up @@ -288,3 +289,48 @@ func NewOptOutCmd() *cobra.Command {

return cmd
}

func NewSetConsumerCommissionRateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "set-consumer-commission-rate [consumer-chain-id] [commission-rate]",
Short: "set a per-consumer chain commission",
Long: strings.TrimSpace(
fmt.Sprintf(`Note that the "commission-rate" argument is a fraction and should be in the range [0,1].
Example:
%s set-consumer-commission-rate consumer-1 0.5 --from node0 --home ../node0`,
version.AppName),
),
Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) error {
clientCtx, err := client.GetClientTxContext(cmd)
if err != nil {
return err
}

txf, err := tx.NewFactoryCLI(clientCtx, cmd.Flags())
if err != nil {
return err
}
txf = txf.WithTxConfig(clientCtx.TxConfig).WithAccountRetriever(clientCtx.AccountRetriever)

providerValAddr := clientCtx.GetFromAddress()

commission, err := sdk.NewDecFromStr(args[1])
if err != nil {
return err
}
msg := types.NewMsgSetConsumerCommissionRate(args[0], commission, sdk.ValAddress(providerValAddr))
if err := msg.ValidateBasic(); err != nil {
return err
}

return tx.GenerateOrBroadcastTxWithFactory(clientCtx, txf, msg)
},
}

flags.AddTxFlagsToCmd(cmd)

_ = cmd.MarkFlagRequired(flags.FlagFrom)

return cmd
}
10 changes: 8 additions & 2 deletions x/ccv/provider/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ func NewHandler(k *keeper.Keeper) sdk.Handler {
case *types.MsgSubmitConsumerMisbehaviour:
res, err := msgServer.SubmitConsumerMisbehaviour(sdk.WrapSDKContext(ctx), msg)
return sdk.WrapServiceResult(ctx, res, err)
case *types.MsgSubmitConsumerDoubleVoting:
res, err := msgServer.SubmitConsumerDoubleVoting(sdk.WrapSDKContext(ctx), msg)
case *types.MsgOptIn:
res, err := msgServer.OptIn(sdk.WrapSDKContext(ctx), msg)
return sdk.WrapServiceResult(ctx, res, err)
case *types.MsgOptOut:
res, err := msgServer.OptOut(sdk.WrapSDKContext(ctx), msg)
return sdk.WrapServiceResult(ctx, res, err)
case *types.MsgSetConsumerCommissionRate:
res, err := msgServer.SetConsumerCommissionRate(sdk.WrapSDKContext(ctx), msg)
return sdk.WrapServiceResult(ctx, res, err)
default:
return nil, errorsmod.Wrapf(sdkerrors.ErrUnknownRequest, "unrecognized %s message type: %T", types.ModuleName, msg)
Expand Down
1 change: 0 additions & 1 deletion x/ccv/provider/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

ibctesting "github.com/cosmos/ibc-go/v7/testing"
"github.com/stretchr/testify/require"

cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
sdk "github.com/cosmos/cosmos-sdk/types"

Expand Down
21 changes: 15 additions & 6 deletions x/ccv/provider/keeper/partial_set_security.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package keeper

import (
"math"
"sort"

errorsmod "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"

"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
"sort"
)

// HandleOptIn prepares validator `providerAddr` to opt in to `chainID` with an optional `consumerKey` consumer public key.
Expand Down Expand Up @@ -51,7 +52,7 @@ func (k Keeper) HandleOptOut(ctx sdk.Context, chainID string, providerAddr types
"opting out of an unknown or not running consumer chain, with id: %s", chainID)
}

if topN, found := k.GetTopN(ctx, chainID); found {
if topN, found := k.GetTopN(ctx, chainID); found && topN > 0 {
// a validator cannot opt out from a Top N chain if the validator is in the Top N validators
validator, validatorFound := k.stakingKeeper.GetValidatorByConsAddr(ctx, providerAddr.ToSdkConsAddr())
if !validatorFound {
Expand All @@ -60,7 +61,7 @@ func (k Keeper) HandleOptOut(ctx sdk.Context, chainID string, providerAddr types
"validator with consensus address %s could not be found", providerAddr.ToSdkConsAddr())
}
power := k.stakingKeeper.GetLastValidatorPower(ctx, validator.GetOperator())
minPowerToOptIn := k.ComputeMinPowerToOptIn(ctx, k.stakingKeeper.GetLastValidators(ctx), topN)
minPowerToOptIn := k.ComputeMinPowerToOptIn(ctx, chainID, k.stakingKeeper.GetLastValidators(ctx), topN)

if power >= minPowerToOptIn {
return errorsmod.Wrapf(
Expand Down Expand Up @@ -94,8 +95,16 @@ func (k Keeper) OptInTopNValidators(ctx sdk.Context, chainID string, bondedValid
}

// ComputeMinPowerToOptIn returns the minimum power needed for a validator (from the bonded validators)
// to belong to the `topN` validators
func (k Keeper) ComputeMinPowerToOptIn(ctx sdk.Context, bondedValidators []stakingtypes.Validator, topN uint32) int64 {
// to belong to the `topN` validators. `chainID` is only used for logging purposes.
func (k Keeper) ComputeMinPowerToOptIn(ctx sdk.Context, chainID string, bondedValidators []stakingtypes.Validator, topN uint32) int64 {
if topN == 0 {
// This should never happen but because `ComputeMinPowerToOptIn` is called during an `EndBlock` we do want
// to `panic` here. Instead, we log an error and return the maximum possible `int64`.
k.Logger(ctx).Error("trying to compute minimum power to opt in for a non-Top-N chain",
"chainID", chainID)
return math.MaxInt64
}

totalPower := sdk.ZeroDec()
var powers []int64

Expand Down
24 changes: 14 additions & 10 deletions x/ccv/provider/keeper/partial_set_security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper_test

import (
"bytes"
"math"
"sort"
"testing"

Expand Down Expand Up @@ -271,16 +272,19 @@ func TestComputeMinPowerToOptIn(t *testing.T) {
createStakingValidator(ctx, mocks, 5, 6),
}

require.Equal(t, int64(1), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 100))
require.Equal(t, int64(1), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 97))
require.Equal(t, int64(3), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 96))
require.Equal(t, int64(3), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 85))
require.Equal(t, int64(5), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 84))
require.Equal(t, int64(5), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 65))
require.Equal(t, int64(6), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 64))
require.Equal(t, int64(6), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 41))
require.Equal(t, int64(10), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 40))
require.Equal(t, int64(10), providerKeeper.ComputeMinPowerToOptIn(ctx, bondedValidators, 1))
require.Equal(t, int64(1), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 100))
require.Equal(t, int64(1), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 97))
require.Equal(t, int64(3), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 96))
require.Equal(t, int64(3), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 85))
require.Equal(t, int64(5), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 84))
require.Equal(t, int64(5), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 65))
require.Equal(t, int64(6), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 64))
require.Equal(t, int64(6), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 41))
require.Equal(t, int64(10), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 40))
require.Equal(t, int64(10), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 1))

// exceptional case when we erroneously call with `topN == 0`
require.Equal(t, int64(math.MaxInt64), providerKeeper.ComputeMinPowerToOptIn(ctx, "chainID", bondedValidators, 0))
}

// TestShouldConsiderOnlyOptIn returns true if `validator` is opted in, in `chainID.
Expand Down
2 changes: 1 addition & 1 deletion x/ccv/provider/keeper/proposal.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (k Keeper) MakeConsumerGenesis(

if prop.Top_N > 0 {
// in a Top-N chain, we automatically opt in all validators that belong to the top N
minPower := k.ComputeMinPowerToOptIn(ctx, bondedValidators, prop.Top_N)
minPower := k.ComputeMinPowerToOptIn(ctx, chainID, bondedValidators, prop.Top_N)
k.OptInTopNValidators(ctx, chainID, bondedValidators, minPower)
}

Expand Down
12 changes: 10 additions & 2 deletions x/ccv/provider/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) {
for _, chain := range k.GetAllConsumerChains(ctx) {
currentValidators := k.GetConsumerValSet(ctx, chain.ChainId)

if topN, found := k.GetTopN(ctx, chain.ChainId); found {
if topN, found := k.GetTopN(ctx, chain.ChainId); found && topN > 0 {
// in a Top-N chain, we automatically opt in all validators that belong to the top N
minPower := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN)
minPower := k.ComputeMinPowerToOptIn(ctx, chain.ChainId, bondedValidators, topN)
k.OptInTopNValidators(ctx, chain.ChainId, bondedValidators, minPower)
}

Expand Down Expand Up @@ -401,6 +401,14 @@ func (k Keeper) HandleSlashPacket(ctx sdk.Context, chainID string, data ccv.Slas
"infractionType", data.Infraction,
)

// Check that the validator belongs to the consumer chain valset
if !k.IsConsumerValidator(ctx, chainID, providerConsAddr) {
k.Logger(ctx).Error("cannot jail validator %s that does not belong to consumer %s valset",
providerConsAddr.String(), chainID)
// drop packet
return
}

// Obtain validator from staking keeper
validator, found := k.stakingKeeper.GetValidatorByConsAddr(ctx, providerConsAddr.ToSdkConsAddr())

Expand Down
65 changes: 44 additions & 21 deletions x/ccv/provider/keeper/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
cryptotestutil "github.com/cosmos/interchain-security/v4/testutil/crypto"
testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper"
"github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
ccv "github.com/cosmos/interchain-security/v4/x/ccv/types"
)
Expand Down Expand Up @@ -136,6 +137,9 @@ func TestOnRecvDowntimeSlashPacket(t *testing.T) {
// Now set slash meter to positive value and assert slash packet handled result is returned
providerKeeper.SetSlashMeter(ctx, math.NewInt(5))

// Set the consumer validator
providerKeeper.SetConsumerValidator(ctx, "chain-1", types.ConsumerValidator{ProviderConsAddr: packetData.Validator.Address})

// Mock call to GetEffectiveValPower, so that it returns 2.
providerAddr := providertypes.NewProviderConsAddress(packetData.Validator.Address)
calls := []*gomock.Call{
Expand Down Expand Up @@ -289,8 +293,11 @@ func TestValidateSlashPacket(t *testing.T) {
func TestHandleSlashPacket(t *testing.T) {
chainId := "consumer-id"
validVscID := uint64(234)

providerConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(7842334).ProviderConsAddress()
consumerConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(784987634).ConsumerConsAddress()
// this "dummy" consensus address won't be stored on the provider states
dummyConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(784987639).ConsumerConsAddress()

testCases := []struct {
name string
Expand All @@ -299,6 +306,20 @@ func TestHandleSlashPacket(t *testing.T) {
expectedCalls func(sdk.Context, testkeeper.MockedKeepers, ccv.SlashPacketData) []*gomock.Call
expectedSlashAcksLen int
}{
{
"validator isn't a consumer validator",
ccv.SlashPacketData{
Validator: abci.Validator{Address: dummyConsAddr.ToSdkConsAddr()},
ValsetUpdateId: validVscID,
Infraction: stakingtypes.Infraction_INFRACTION_DOWNTIME,
},
func(ctx sdk.Context, mocks testkeeper.MockedKeepers,
expectedPacketData ccv.SlashPacketData,
) []*gomock.Call {
return []*gomock.Call{}
},
0,
},
{
"unfound validator",
ccv.SlashPacketData{
Expand Down Expand Up @@ -403,34 +424,36 @@ func TestHandleSlashPacket(t *testing.T) {
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(
t, testkeeper.NewInMemKeeperParams(t))

providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(
t, testkeeper.NewInMemKeeperParams(t))

// Setup expected mock calls
gomock.InOrder(tc.expectedCalls(ctx, mocks, tc.packetData)...)
// Setup expected mock calls
gomock.InOrder(tc.expectedCalls(ctx, mocks, tc.packetData)...)

// Setup init chain height and a single valid valset update ID to block height mapping.
providerKeeper.SetInitChainHeight(ctx, chainId, 5)
providerKeeper.SetValsetUpdateBlockHeight(ctx, validVscID, 99)
// Setup init chain height and a single valid valset update ID to block height mapping.
providerKeeper.SetInitChainHeight(ctx, chainId, 5)
providerKeeper.SetValsetUpdateBlockHeight(ctx, validVscID, 99)

// Setup consumer address to provider address mapping.
require.NotEmpty(t, tc.packetData.Validator.Address)
providerKeeper.SetValidatorByConsumerAddr(ctx, chainId, consumerConsAddr, providerConsAddr)
// Setup consumer address to provider address mapping.
require.NotEmpty(t, tc.packetData.Validator.Address)
providerKeeper.SetValidatorByConsumerAddr(ctx, chainId, consumerConsAddr, providerConsAddr)
providerKeeper.SetConsumerValidator(ctx, chainId, types.ConsumerValidator{ProviderConsAddr: providerConsAddr.Address.Bytes()})

// Execute method and assert expected mock calls.
providerKeeper.HandleSlashPacket(ctx, chainId, tc.packetData)
// Execute method and assert expected mock calls.
providerKeeper.HandleSlashPacket(ctx, chainId, tc.packetData)

require.Equal(t, tc.expectedSlashAcksLen, len(providerKeeper.GetSlashAcks(ctx, chainId)))
require.Equal(t, tc.expectedSlashAcksLen, len(providerKeeper.GetSlashAcks(ctx, chainId)))

if tc.expectedSlashAcksLen == 1 {
// must match the consumer address
require.Equal(t, consumerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0])
require.NotEqual(t, providerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0])
require.NotEqual(t, providerConsAddr.String(), consumerConsAddr.String())
}
if tc.expectedSlashAcksLen == 1 {
// must match the consumer address
require.Equal(t, consumerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0])
require.NotEqual(t, providerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0])
require.NotEqual(t, providerConsAddr.String(), consumerConsAddr.String())
}

ctrl.Finish()
ctrl.Finish()
})
}
}

Expand Down
9 changes: 9 additions & 0 deletions x/ccv/provider/types/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,15 @@ func (msg MsgOptOut) ValidateBasic() error {
return nil
}

// NewMsgSetConsumerCommissionRate creates a new MsgSetConsumerCommissionRate msg instance.
func NewMsgSetConsumerCommissionRate(chainID string, commission sdk.Dec, providerValidatorAddress sdk.ValAddress) *MsgSetConsumerCommissionRate {
return &MsgSetConsumerCommissionRate{
ChainId: chainID,
Rate: commission,
ProviderAddr: providerValidatorAddress.String(),
}
}

// Type implements the sdk.Msg interface.
func (msg MsgOptOut) Type() string {
return TypeMsgOptOut
Expand Down

0 comments on commit 22ca561

Please sign in to comment.