Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: improve message validation #1460

Merged
merged 15 commits into from
Dec 1, 2023
3 changes: 3 additions & 0 deletions .changelog/unreleased/bug-fixes/1460-msg-validation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Improve validation of IBC packet data and provider messages. Also,
enable the provider to validate consumer packets before handling them.
([\#1460](https://github.com/cosmos/interchain-security/pull/1460))
3 changes: 3 additions & 0 deletions .changelog/unreleased/state-breaking/1460-msg-validation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Improve validation of IBC packet data and provider messages. Also,
enable the provider to validate consumer packets before handling them.
([\#1460](https://github.com/cosmos/interchain-security/pull/1460))
2 changes: 1 addition & 1 deletion tests/integration/expired_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *CCVTestSuite) TestConsumerPacketSendExpiredClient() {

// try to send slash packet for downtime infraction
addr := ed25519.GenPrivKey().PubKey().Address()
val := abci.Validator{Address: addr}
val := abci.Validator{Address: addr, Power: 1}
consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 2, stakingtypes.Infraction_INFRACTION_DOWNTIME)
// try to send slash packet for the same downtime infraction
consumerKeeper.QueueSlashPacket(s.consumerCtx(), val, 3, stakingtypes.Infraction_INFRACTION_DOWNTIME)
Expand Down
162 changes: 72 additions & 90 deletions tests/integration/slashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkaddress "github.com/cosmos/cosmos-sdk/types/address"
evidencetypes "github.com/cosmos/cosmos-sdk/x/evidence/types"
slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
Expand Down Expand Up @@ -270,12 +271,14 @@ func (s *CCVTestSuite) TestSlashPacketAcknowledgement() {
// Map infraction height on provider so validation passes and provider returns valid ack result
providerKeeper.SetValsetUpdateBlockHeight(s.providerCtx(), spd.ValsetUpdateId, 47923)

exportedAck := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd)
s.Require().NotNil(exportedAck)
ackResult, err := providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, spd)
s.Require().NotNil(ackResult)
s.Require().NoError(err)
exportedAck := channeltypes.NewResultAcknowledgement(ackResult)

// Unmarshal ack to struct that's compatible with consumer. IBC does this automatically
ack := channeltypes.Acknowledgement{}
err := channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack)
err = channeltypes.SubModuleCdc.UnmarshalJSON(exportedAck.Acknowledgement(), &ack)
s.Require().NoError(err)

err = consumerKeeper.OnAcknowledgementPacket(s.consumerCtx(), packet, ack)
Expand Down Expand Up @@ -329,9 +332,7 @@ func (suite *CCVTestSuite) TestHandleSlashPacketDowntime() {
// TestOnRecvSlashPacketErrors tests errors for the OnRecvSlashPacket method in an integration testing setting
func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() {
providerKeeper := suite.providerApp.GetProviderKeeper()
providerSlashingKeeper := suite.providerApp.GetTestSlashingKeeper()
firstBundle := suite.getFirstBundle()
consumerChainID := firstBundle.Chain.ChainID

suite.SetupAllCCVChannels()

Expand All @@ -340,100 +341,80 @@ func (suite *CCVTestSuite) TestOnRecvSlashPacketErrors() {

// Expect panic if ccv channel is not established via dest channel of packet
suite.Panics(func() {
providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{})
_, _ = providerKeeper.OnRecvSlashPacket(ctx, channeltypes.Packet{}, ccv.SlashPacketData{})
})

// Add correct channelID to packet. Now we will not panic anymore.
packet := channeltypes.Packet{DestinationChannel: firstBundle.Path.EndpointB.ChannelID}
suite.NotPanics(func() {
_, _ = providerKeeper.OnRecvSlashPacket(ctx, packet, ccv.SlashPacketData{})
})

// Init chain height is set by established CCV channel
// Delete init chain height and confirm expected error
initChainHeight, found := providerKeeper.GetInitChainHeight(ctx, consumerChainID)
suite.Require().True(found)
providerKeeper.DeleteInitChainHeight(ctx, consumerChainID)

packetData := ccv.SlashPacketData{ValsetUpdateId: 0}
errAck := providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast := errAck.(channeltypes.Acknowledgement)
// Error strings in err acks are now thrown out by IBC core to prevent app hash.
// Hence a generic error string is expected.
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// Restore init chain height
providerKeeper.SetInitChainHeight(ctx, consumerChainID, initChainHeight)

// now the method will fail at infraction height check.
packetData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED
errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast = errAck.(channeltypes.Acknowledgement)
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// save current VSC ID
vscID := providerKeeper.GetValidatorSetUpdateId(ctx)

// remove block height value mapped to current VSC ID
providerKeeper.DeleteValsetUpdateBlockHeight(ctx, vscID)

// Instantiate packet data with current VSC ID
packetData = ccv.SlashPacketData{ValsetUpdateId: vscID}

// expect an error if mapped block height is not found
errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, packetData)
suite.Require().False(errAck.Success())
errAckCast = errAck.(channeltypes.Acknowledgement)
suite.Require().Equal("ABCI code: 1: error handling packet: see events for details", errAckCast.GetError())

// construct slashing packet with non existing validator
slashingPkt := ccv.NewSlashPacketData(
// Check ValidateBasic for SlashPacket data
validAddress := ed25519.GenPrivKey().PubKey().Address()
slashPacketData := ccv.NewSlashPacketData(
abci.Validator{
Address: ed25519.GenPrivKey().PubKey().Address(),
Power: int64(0),
Address: validAddress,
Power: int64(1),
}, uint64(0), stakingtypes.Infraction_INFRACTION_DOWNTIME,
)

// Set initial block height for consumer chain
providerKeeper.SetInitChainHeight(ctx, consumerChainID, uint64(ctx.BlockHeight()))

// Expect no error ack if validator does not exist
// TODO: this behavior should be changed to return an error ack,
// see: https://github.com/cosmos/interchain-security/issues/546
ack := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().True(ack.Success())

val := suite.providerChain.Vals.Validators[0]

// commit block to set VSC ID
suite.coordinator.CommitBlock(suite.providerChain)
// Update suite.ctx bc CommitBlock updates only providerChain's current header block height
ctx = suite.providerChain.GetContext()
suite.Require().NotZero(providerKeeper.GetValsetUpdateBlockHeight(ctx, vscID))

// create validator signing info
valInfo := slashingtypes.NewValidatorSigningInfo(sdk.ConsAddress(val.Address), ctx.BlockHeight(),
ctx.BlockHeight()-1, time.Time{}.UTC(), false, int64(0))
providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(val.Address), valInfo)

// update validator address and VSC ID
slashingPkt.Validator.Address = val.Address
slashingPkt.ValsetUpdateId = vscID

// expect error ack when infraction type in unspecified
tmAddr := suite.providerChain.Vals.Validators[1].Address
slashingPkt.Validator.Address = tmAddr
slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED

valInfo.Address = sdk.ConsAddress(tmAddr).String()
providerSlashingKeeper.SetValidatorSigningInfo(ctx, sdk.ConsAddress(tmAddr), valInfo)

errAck = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().False(errAck.Success())

// expect to queue entries for the slash request
slashingPkt.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME
ack = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashingPkt)
suite.Require().True(ack.Success())
// Expect an error if validator address is too long
slashPacketData.Validator.Address = make([]byte, sdkaddress.MaxAddrLen+1)
_, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator address")

// Expect an error if validator power is zero
slashPacketData.Validator.Address = validAddress
slashPacketData.Validator.Power = 0
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid validator power")

// Expect an error if the infraction type is unspecified
slashPacketData.Validator.Power = 1
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_UNSPECIFIED
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "validating SlashPacket data should fail - invalid infraction type")

// Restore slashPacketData to be valid
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME

// Check ValidateSlashPacket
// Expect an error if a mapping of the infraction height cannot be found;
// just set the vscID of the slash packet to the latest mapped vscID +1
valsetUpdateBlockHeights := providerKeeper.GetAllValsetUpdateBlockHeights(ctx)
latestMappedValsetUpdateId := valsetUpdateBlockHeights[len(valsetUpdateBlockHeights)-1].ValsetUpdateId
slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId + 1
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().Error(err, "ValidateSlashPacket should fail - no infraction height mapping")

// Restore slashPacketData to be valid
slashPacketData.ValsetUpdateId = latestMappedValsetUpdateId

// Expect no error if validator does not exist
_, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")

// Check expected behavior for handling SlashPackets for double signing infractions
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN
ackResult, err := providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.V1Result, ackResult, "expected successful ack")

// Check expected behavior for handling SlashPackets for downtime infractions
slashPacketData.Infraction = stakingtypes.Infraction_INFRACTION_DOWNTIME

// Expect the packet to bounce if the slash meter is negative
providerKeeper.SetSlashMeter(ctx, sdk.NewInt(-1))
ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.SlashPacketBouncedResult, ackResult, "expected successful ack")

// Expect the packet to be handled if the slash meter is positive
providerKeeper.SetSlashMeter(ctx, sdk.NewInt(0))
ackResult, err = providerKeeper.OnRecvSlashPacket(ctx, packet, *slashPacketData)
suite.Require().NoError(err, "no error expected")
suite.Require().Equal(ccv.SlashPacketHandledResult, ackResult, "expected successful ack")
}

// TestValidatorDowntime tests if a slash packet is sent
Expand Down Expand Up @@ -649,6 +630,7 @@ func (suite *CCVTestSuite) TestQueueAndSendSlashPacket() {
addr := ed25519.GenPrivKey().PubKey().Address()
val := abci.Validator{
Address: addr,
Power: int64(1),
}
consumerKeeper.QueueSlashPacket(ctx, val, 0, infraction)
slashedVals = append(slashedVals, slashedVal{validator: val, infraction: infraction})
Expand Down
9 changes: 6 additions & 3 deletions tests/integration/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ func (s *CCVTestSuite) TestPacketSpam() {
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// Execute block
Expand Down Expand Up @@ -411,7 +412,8 @@ func (s *CCVTestSuite) TestDoubleSignDoesNotAffectThrottling() {
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, uint64(sequence), firstBundle.Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// Execute block to handle packets in endblock
Expand Down Expand Up @@ -581,7 +583,8 @@ func (s CCVTestSuite) TestSlashAllValidators() { //nolint:govet // this is a tes
consumerPacketData, err := provider.UnmarshalConsumerPacketData(data) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
packet := s.newPacketFromConsumer(data, ibcSeqNum, s.getFirstBundle().Path, timeoutHeight, timeoutTimestamp)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
_, err = providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
s.Require().NoError(err)
}

// Check that all validators are jailed.
Expand Down
15 changes: 6 additions & 9 deletions tests/integration/valset_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {

// send first packet
packet := suite.newPacketFromProvider(pd.GetBytes(), 1, suite.path, clienttypes.NewHeight(1, 0), 0)
ack := consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

// increase time
incrementTime(suite, time.Hour)
Expand All @@ -83,9 +82,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {
pd.ValsetUpdateId = 2
packet.Data = pd.GetBytes()
packet.Sequence = 2
ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

// increase time
incrementTime(suite, 24*time.Hour)
Expand All @@ -95,9 +93,8 @@ func (suite *CCVTestSuite) TestQueueAndSendVSCMaturedPackets() {
pd.ValsetUpdateId = 3
packet.Data = pd.GetBytes()
packet.Sequence = 3
ack = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().NotNil(ack, "OnRecvVSCPacket did not return ack")
suite.Require().True(ack.Success(), "OnRecvVSCPacket did not return a Success Acknowledgment")
err = consumerKeeper.OnRecvVSCPacket(suite.consumerChain.GetContext(), packet, pd)
suite.Require().Nil(err, "OnRecvVSCPacket did return non-nil error")

packetMaturities := consumerKeeper.GetAllPacketMaturityTimes(suite.consumerChain.GetContext())

Expand Down
44 changes: 34 additions & 10 deletions x/ccv/consumer/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package consumer

import (
"fmt"
"strconv"
"strings"

transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types"
Expand Down Expand Up @@ -225,25 +226,48 @@ func (am AppModule) OnRecvPacket(
packet channeltypes.Packet,
_ sdk.AccAddress,
) ibcexported.Acknowledgement {
var (
ack ibcexported.Acknowledgement
data types.ValidatorSetChangePacketData
)
logger := am.keeper.Logger(ctx)
ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)})

var data types.ValidatorSetChangePacketData
var ackErr error
if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
errAck := types.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("cannot unmarshal CCV packet data"))
ack = &errAck
} else {
ack = am.keeper.OnRecvVSCPacket(ctx, packet, data)
ackErr = errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal VSCPacket data")
logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence))
ack = channeltypes.NewErrorAcknowledgement(ackErr)
}

// only attempt the application logic if the packet data
// was successfully decoded
if ack.Success() {
err := am.keeper.OnRecvVSCPacket(ctx, packet, data)
if err != nil {
ack = channeltypes.NewErrorAcknowledgement(err)
ackErr = err
logger.Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence))
} else {
logger.Info("successfully handled VSCPacket sequence: %d", packet.Sequence)
}
}

eventAttributes := []sdk.Attribute{
sdk.NewAttribute(sdk.AttributeKeyModule, types.ModuleName),
sdk.NewAttribute(types.AttributeValSetUpdateID, strconv.Itoa(int(data.ValsetUpdateId))),
sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success())),
}

if ackErr != nil {
eventAttributes = append(eventAttributes, sdk.NewAttribute(types.AttributeKeyAckError, ackErr.Error()))
}

ctx.EventManager().EmitEvent(
sdk.NewEvent(
types.EventTypePacket,
sdk.NewAttribute(sdk.AttributeKeyModule, consumertypes.ModuleName),
sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack != nil)),
eventAttributes...,
),
)

// NOTE: acknowledgement will be written synchronously during IBC handler execution.
return ack
}

Expand Down
11 changes: 7 additions & 4 deletions x/ccv/consumer/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
"github.com/cosmos/ibc-go/v7/modules/core/exported"

errorsmod "cosmossdk.io/errors"

Expand All @@ -25,7 +24,12 @@ import (
//
// Note: CCV uses an ordered IBC channel, meaning VSC packet changes will be accumulated (and later
// processed by ApplyCCValidatorChanges) s.t. more recent val power changes overwrite older ones.
func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) exported.Acknowledgement {
MSalopek marked this conversation as resolved.
Show resolved Hide resolved
func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) error {
// validate packet data upon receiving
if err := newChanges.ValidateBasic(); err != nil {
mpoke marked this conversation as resolved.
Show resolved Hide resolved
return errorsmod.Wrapf(err, "error validating VSCPacket data")
}

// get the provider channel
providerChannel, found := k.GetProviderChannel(ctx)
if found && providerChannel != packet.DestinationChannel {
Expand Down Expand Up @@ -87,8 +91,7 @@ func (k Keeper) OnRecvVSCPacket(ctx sdk.Context, packet channeltypes.Packet, new
"len updates", len(newChanges.ValidatorUpdates),
"len slash acks", len(newChanges.SlashAcks),
)
ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)})
return ack
return nil
}

// QueueVSCMaturedPackets appends matured VSCs to an internal queue.
Expand Down
Loading
Loading