diff --git a/x/ccv/provider/types/msg.go b/x/ccv/provider/types/msg.go index c47c62e0a0..f9b84af0e4 100644 --- a/x/ccv/provider/types/msg.go +++ b/x/ccv/provider/types/msg.go @@ -103,10 +103,14 @@ func (msg MsgAssignConsumerKey) ValidateBasic() error { if 128 < len(msg.ChainId) { return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot exceed 128 length") } - _, err := sdk.ValAddressFromBech32(msg.ProviderAddr) + valAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr) if err != nil { return ErrInvalidProviderAddress } + // Check that the provider validator address and the signer address are the same + if sdk.AccAddress(valAddr.Bytes()).String() != msg.Signer { + return errorsmod.Wrapf(ErrInvalidProviderAddress, "provider validator address must be the same as the signer address") + } if msg.ConsumerKey == "" { return ErrInvalidConsumerConsensusPubKey } @@ -357,11 +361,14 @@ func (msg MsgOptIn) ValidateBasic() error { if 128 < len(msg.ChainId) { return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot exceed 128 length") } - _, err := sdk.ValAddressFromBech32(msg.ProviderAddr) + valAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr) if err != nil { return ErrInvalidProviderAddress } - + // Check that the provider validator address and the signer address are the same + if sdk.AccAddress(valAddr.Bytes()).String() != msg.Signer { + return errorsmod.Wrapf(ErrInvalidProviderAddress, "provider validator address must be the same as the signer address") + } if msg.ConsumerKey != "" { if _, _, err := ParseConsumerKeyFromJson(msg.ConsumerKey); err != nil { return ErrInvalidConsumerConsensusPubKey @@ -415,10 +422,15 @@ func (msg MsgOptOut) ValidateBasic() error { if 128 < len(msg.ChainId) { return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot exceed 128 length") } - _, err := sdk.ValAddressFromBech32(msg.ProviderAddr) + valAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr) if err != nil { return ErrInvalidProviderAddress } + // Check that the provider validator address and the signer address are the same + if sdk.AccAddress(valAddr.Bytes()).String() != msg.Signer { + return errorsmod.Wrapf(ErrInvalidProviderAddress, "provider validator address must be the same as the signer address") + } + return nil } @@ -444,15 +456,17 @@ func (msg MsgSetConsumerCommissionRate) ValidateBasic() error { if strings.TrimSpace(msg.ChainId) == "" { return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot be blank") } - if 128 < len(msg.ChainId) { return errorsmod.Wrapf(ErrInvalidConsumerChainID, "chainId cannot exceed 128 length") } - _, err := sdk.ValAddressFromBech32(msg.ProviderAddr) + valAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr) if err != nil { return ErrInvalidProviderAddress } - + // Check that the provider validator address and the signer address are the same + if sdk.AccAddress(valAddr.Bytes()).String() != msg.Signer { + return errorsmod.Wrapf(ErrInvalidProviderAddress, "provider validator address must be the same as the signer address") + } if msg.Rate.IsNegative() || msg.Rate.GT(math.LegacyOneDec()) { return errorsmod.Wrapf(ErrInvalidConsumerCommissionRate, "consumer commission rate should be in the range [0, 1]") } diff --git a/x/ccv/provider/types/msg_test.go b/x/ccv/provider/types/msg_test.go new file mode 100644 index 0000000000..4ef2a7efc8 --- /dev/null +++ b/x/ccv/provider/types/msg_test.go @@ -0,0 +1,89 @@ +package types_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + cryptoutil "github.com/cosmos/interchain-security/v5/testutil/crypto" + "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" + "github.com/stretchr/testify/require" +) + +func TestMsgAssignConsumerKeyValidateBasic(t *testing.T) { + cId1 := cryptoutil.NewCryptoIdentityFromIntSeed(35443543534) + cId2 := cryptoutil.NewCryptoIdentityFromIntSeed(65465464564) + + valOpAddr1 := cId1.SDKValOpAddress() + acc1 := sdk.AccAddress(valOpAddr1.Bytes()).String() + acc2 := sdk.AccAddress(cId2.SDKValOpAddress().Bytes()).String() + + longChainId := "abcdefghijklmnopqrstuvwxyz" + for i := 0; i < 3; i++ { + longChainId += longChainId + } + + testCases := []struct { + name string + chainId string + providerAddr string + signer string + consumerKey string + expErr bool + }{ + { + name: "chain Id empty", + expErr: true, + }, + { + name: "chain Id too long", + chainId: longChainId, + expErr: true, + }, + { + name: "invalid provider address", + chainId: "chainId", + expErr: true, + }, + { + name: "invalid signer address: must be the same as the provider address", + chainId: "chainId", + providerAddr: valOpAddr1.String(), + signer: acc2, + expErr: true, + }, + { + name: "invalid consumer pubkey", + chainId: "chainId", + providerAddr: valOpAddr1.String(), + signer: acc1, + expErr: true, + }, + { + name: "valid assign consumer key msg", + chainId: "chainId", + providerAddr: valOpAddr1.String(), + consumerKey: "{\"@type\": \"/cosmos.crypto.ed25519.PubKey\", \"key\": \"e3BehnEIlGUAnJYn9V8gBXuMh4tXO8xxlxyXD1APGyk=\"}", + signer: acc1, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + msg := types.MsgAssignConsumerKey{ + ChainId: tc.chainId, + ConsumerKey: tc.consumerKey, + ProviderAddr: tc.providerAddr, + Signer: tc.signer, + } + + err := msg.ValidateBasic() + if tc.expErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +}