diff --git a/tests/integration/throttle.go b/tests/integration/throttle.go index 92800cd1bb..997765a2c2 100644 --- a/tests/integration/throttle.go +++ b/tests/integration/throttle.go @@ -315,8 +315,14 @@ func (s *CCVTestSuite) TestPacketSpam() { // Recv 500 packets from consumer to provider in same block for _, packet := range packets { consumerPacketData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1()) + } } // Execute block to handle packets in endblock @@ -369,8 +375,14 @@ func (s *CCVTestSuite) TestDoubleSignDoesNotAffectThrottling() { // Recv 500 packets from consumer to provider in same block for _, packet := range packets { consumerPacketData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1()) + } } // Execute block to handle packets in endblock @@ -465,7 +477,17 @@ func (s *CCVTestSuite) TestQueueOrdering() { // Recv 500 packets from consumer to provider in same block for i, packet := range packets { consumerPacketData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketData) + consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData) + if err != nil { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1) + consumerPacketData = ccvtypes.ConsumerPacketData{ + Type: consumerPacketDataV1.Type, + Data: &ccvtypes.ConsumerPacketData_SlashPacketData{ + SlashPacketData: consumerPacketDataV1.GetSlashPacketData().FromV1(), + }, + } + } // Type depends on index packets were appended from above if (i+5)%10 == 0 { vscMaturedPacketData := consumerPacketData.GetVscMaturedPacketData() @@ -679,8 +701,14 @@ func (s *CCVTestSuite) TestSlashSameValidator() { // Recv and queue all slash packets. for _, packet := range packets { consumerPacketData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1()) + } } // We should have 6 pending slash packet entries queued. @@ -740,8 +768,14 @@ func (s CCVTestSuite) TestSlashAllValidators() { //nolint:govet // this is a tes // Recv and queue all slash packets. for _, packet := range packets { consumerPacketData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1()) + } } // We should have 24 pending slash packet entries queued. @@ -787,9 +821,14 @@ func (s *CCVTestSuite) TestLeadingVSCMaturedAreDequeued() { packet := s.constructSlashPacketFromConsumer(*bundle, *s.providerChain.Vals.Validators[0], stakingtypes.Infraction_INFRACTION_DOWNTIME, ibcSeqNum) packetData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), - packet, *packetData.GetSlashPacketData()) + packetDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &packetData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetDataV1.GetSlashPacketData().FromV1()) + } } } @@ -878,9 +917,14 @@ func (s *CCVTestSuite) TestVscMaturedHandledPerBlockLimit() { packet := s.constructSlashPacketFromConsumer(*bundle, *s.providerChain.Vals.Validators[0], stakingtypes.Infraction_INFRACTION_DOWNTIME, ibcSeqNum) packetData := ccvtypes.ConsumerPacketData{} - ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetData) - providerKeeper.OnRecvSlashPacket(s.providerCtx(), - packet, *packetData.GetSlashPacketData()) + packetDataV1 := ccvtypes.ConsumerPacketDataV1{} + err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &packetData) + if err == nil { + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetData.GetSlashPacketData()) + } else { + ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetDataV1) + providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetDataV1.GetSlashPacketData().FromV1()) + } } } diff --git a/x/ccv/provider/ibc_module.go b/x/ccv/provider/ibc_module.go index b543c8927e..b9b0cd817d 100644 --- a/x/ccv/provider/ibc_module.go +++ b/x/ccv/provider/ibc_module.go @@ -175,28 +175,49 @@ func (am AppModule) OnRecvPacket( _ sdk.AccAddress, ) ibcexported.Acknowledgement { var ( - ack ibcexported.Acknowledgement - consumerPacket ccv.ConsumerPacketData + ack ibcexported.Acknowledgement + consumerPacket ccv.ConsumerPacketData + consumerPacketV1 ccv.ConsumerPacketDataV1 + isV1Packet bool ) + // unmarshall consumer packet if err := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacket); err != nil { - errAck := channeltypes.NewErrorAcknowledgement(fmt.Errorf("cannot unmarshal CCV packet data")) - ack = &errAck - } else { - // TODO: call ValidateBasic method on consumer packet data - // See: https://github.com/cosmos/interchain-security/issues/634 - - switch consumerPacket.Type { - case ccv.VscMaturedPacket: - // handle VSCMaturedPacket - ack = am.keeper.OnRecvVSCMaturedPacket(ctx, packet, *consumerPacket.GetVscMaturedPacketData()) - case ccv.SlashPacket: - // handle SlashPacket - ack = am.keeper.OnRecvSlashPacket(ctx, packet, *consumerPacket.GetSlashPacketData()) - default: - errAck := channeltypes.NewErrorAcknowledgement(fmt.Errorf("invalid consumer packet type: %q", consumerPacket.Type)) + // retry for v1 packet type + errV1 := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketV1) + if errV1 != nil { + errAck := channeltypes.NewErrorAcknowledgement(fmt.Errorf("cannot unmarshal CCV packet data")) ack = &errAck + + ctx.EventManager().EmitEvent( + sdk.NewEvent( + ccv.EventTypePacket, + sdk.NewAttribute(sdk.AttributeKeyModule, providertypes.ModuleName), + sdk.NewAttribute(ccv.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack != nil)), + ), + ) + return ack + } + isV1Packet = true + } + + // TODO: call ValidateBasic method on consumer packet data + // See: https://github.com/cosmos/interchain-security/issues/634 + + switch consumerPacket.Type { + case ccv.VscMaturedPacket: + // handle VSCMaturedPacket + ack = am.keeper.OnRecvVSCMaturedPacket(ctx, packet, *consumerPacket.GetVscMaturedPacketData()) + case ccv.SlashPacket: + // handle SlashPacket + if isV1Packet { + ack = am.keeper.OnRecvSlashPacket(ctx, packet, *consumerPacketV1.GetSlashPacketData().FromV1()) + } else { + ack = am.keeper.OnRecvSlashPacket(ctx, packet, *consumerPacket.GetSlashPacketData()) } + default: + errAck := channeltypes.NewErrorAcknowledgement(fmt.Errorf("invalid consumer packet type: %q", consumerPacket.Type)) + ack = &errAck } ctx.EventManager().EmitEvent( diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go index 1aa0858da6..b9671ce58b 100644 --- a/x/ccv/provider/keeper/relay.go +++ b/x/ccv/provider/keeper/relay.go @@ -400,6 +400,26 @@ func (k Keeper) ValidateSlashPacket(ctx sdk.Context, chainID string, return nil } +// ValidateV1SlashPacket validates a recv slash packet compatible with v1 before it is +// handled or persisted in store. An error is returned if the packet is invalid, +// and an error ack should be relayed to the sender. +func (k Keeper) ValidateV1SlashPacket(ctx sdk.Context, chainID string, + packet channeltypes.Packet, data ccv.SlashPacketDataV1, +) error { + _, found := k.getMappedInfractionHeight(ctx, chainID, data.ValsetUpdateId) + // return error if we cannot find infraction height matching the validator update id + if !found { + return fmt.Errorf("cannot find infraction height matching "+ + "the validator update id %d for chain %s", data.ValsetUpdateId, chainID) + } + + if data.Infraction != ccv.DoubleSign && data.Infraction != ccv.Downtime { + return fmt.Errorf("invalid infraction type: %s", data.Infraction) + } + + return nil +} + // HandleSlashPacket potentially jails a misbehaving validator for a downtime infraction. // This method should NEVER be called with a double-sign infraction. func (k Keeper) HandleSlashPacket(ctx sdk.Context, chainID string, data ccv.SlashPacketData) { diff --git a/x/ccv/types/ccv.go b/x/ccv/types/ccv.go index 4bcb4556c2..e7739b4c2f 100644 --- a/x/ccv/types/ccv.go +++ b/x/ccv/types/ccv.go @@ -93,6 +93,10 @@ func (vdt SlashPacketData) GetBytes() []byte { return valDowntimeBytes } +func (vdt SlashPacketData) ToV1() *SlashPacketDataV1 { + return NewSlashPacketDataV1(vdt.Validator, vdt.ValsetUpdateId, vdt.Infraction) +} + func (cp ConsumerPacketData) ValidateBasic() (err error) { switch cp.Type { case VscMaturedPacket: @@ -140,3 +144,21 @@ func (cp ConsumerPacketData) ToV1Bytes() []byte { bytes := ModuleCdc.MustMarshalJSON(&cpv1) return bytes } + +// FromV1 converts SlashPacketDataV1 to SlashPacketData. +// Provider must handle both V1 and later versions of the SlashPacketData. +func (vdt1 SlashPacketDataV1) FromV1() *SlashPacketData { + newType := stakingtypes.Infraction_INFRACTION_UNSPECIFIED + switch vdt1.Infraction { + case Downtime: + newType = stakingtypes.Infraction_INFRACTION_DOWNTIME + case DoubleSign: + newType = stakingtypes.Infraction_INFRACTION_DOUBLE_SIGN + } + + return &SlashPacketData{ + Validator: vdt1.Validator, + ValsetUpdateId: vdt1.ValsetUpdateId, + Infraction: newType, + } +}