From 441bcee2f8161e6c8b0d07a30f5f89713da57bbb Mon Sep 17 00:00:00 2001 From: Simon Noetzlin Date: Wed, 22 Nov 2023 18:32:39 +0100 Subject: [PATCH] fix #1282 --- x/ccv/provider/keeper/hooks.go | 44 +---------------- x/ccv/provider/keeper/hooks_test.go | 2 +- x/ccv/provider/keeper/keeper.go | 16 ++++++ x/ccv/provider/keeper/key_assignment.go | 52 ++++++++++++++++++++ x/ccv/provider/keeper/key_assignment_test.go | 51 +++++++++++++------ 5 files changed, 105 insertions(+), 60 deletions(-) diff --git a/x/ccv/provider/keeper/hooks.go b/x/ccv/provider/keeper/hooks.go index 3be2be04eb..cfb9090981 100644 --- a/x/ccv/provider/keeper/hooks.go +++ b/x/ccv/provider/keeper/hooks.go @@ -91,50 +91,8 @@ func (h Hooks) AfterUnbondingInitiated(ctx sdk.Context, id uint64) error { return nil } -// ValidatorConsensusKeyInUse is called when a new validator is created -// in the x/staking module of cosmos-sdk. In case it panics, the TX aborts -// and thus, the validator is not created. See AfterValidatorCreated hook. -func ValidatorConsensusKeyInUse(k *Keeper, ctx sdk.Context, valAddr sdk.ValAddress) bool { - // Get the validator being added in the staking module. - val, found := k.stakingKeeper.GetValidator(ctx, valAddr) - if !found { - // Abort TX, do NOT allow validator to be created - panic("did not find newly created validator in staking module") - } - - // Get the consensus address of the validator being added - consensusAddr, err := val.GetConsAddr() - if err != nil { - // Abort TX, do NOT allow validator to be created - panic("could not get validator cons addr ") - } - - allConsumerChains := []string{} - consumerChains := k.GetAllConsumerChains(ctx) - for _, consumerChain := range consumerChains { - allConsumerChains = append(allConsumerChains, consumerChain.ChainId) - } - proposedChains := k.GetAllProposedConsumerChainIDs(ctx) - for _, proposedChain := range proposedChains { - allConsumerChains = append(allConsumerChains, proposedChain.ChainID) - } - pendingChainIDs := k.GetAllPendingConsumerChainIDs(ctx) - allConsumerChains = append(allConsumerChains, pendingChainIDs...) - - for _, c := range allConsumerChains { - if _, exist := k.GetValidatorByConsumerAddr( - ctx, - c, - providertypes.NewConsumerConsAddress(consensusAddr)); exist { - return true - } - } - - return false -} - func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error { - if ValidatorConsensusKeyInUse(h.k, ctx, valAddr) { + if h.k.ValidatorConsensusKeyInUse(ctx, valAddr) { // Abort TX, do NOT allow validator to be created panic("cannot create a validator with a consensus key that is already in use or was recently in use as an assigned consumer chain key") } diff --git a/x/ccv/provider/keeper/hooks_test.go b/x/ccv/provider/keeper/hooks_test.go index 0bdf9b26fa..47175eb907 100644 --- a/x/ccv/provider/keeper/hooks_test.go +++ b/x/ccv/provider/keeper/hooks_test.go @@ -83,7 +83,7 @@ func TestValidatorConsensusKeyInUse(t *testing.T) { tt.setup(ctx, k) t.Run(tt.name, func(t *testing.T) { - if actual := providerkeeper.ValidatorConsensusKeyInUse(&k, ctx, newValidator.SDKStakingValidator().GetOperator()); actual != tt.expect { + if actual := k.ValidatorConsensusKeyInUse(ctx, newValidator.SDKStakingValidator().GetOperator()); actual != tt.expect { t.Errorf("validatorConsensusKeyInUse() = %v, want %v", actual, tt.expect) } }) diff --git a/x/ccv/provider/keeper/keeper.go b/x/ccv/provider/keeper/keeper.go index 0b9fa0420d..f0c6195cfe 100644 --- a/x/ccv/provider/keeper/keeper.go +++ b/x/ccv/provider/keeper/keeper.go @@ -1123,3 +1123,19 @@ func (k Keeper) GetSlashLog( func (k Keeper) BondDenom(ctx sdk.Context) string { return k.stakingKeeper.BondDenom(ctx) } + +func (k Keeper) GetAllRegisteredAndProposedChainIDs(ctx sdk.Context) []string { + allConsumerChains := []string{} + consumerChains := k.GetAllConsumerChains(ctx) + for _, consumerChain := range consumerChains { + allConsumerChains = append(allConsumerChains, consumerChain.ChainId) + } + proposedChains := k.GetAllProposedConsumerChainIDs(ctx) + for _, proposedChain := range proposedChains { + allConsumerChains = append(allConsumerChains, proposedChain.ChainID) + } + pendingChainIDs := k.GetAllPendingConsumerChainIDs(ctx) + allConsumerChains = append(allConsumerChains, pendingChainIDs...) + + return allConsumerChains +} diff --git a/x/ccv/provider/keeper/key_assignment.go b/x/ccv/provider/keeper/key_assignment.go index d440848bbf..49d940a947 100644 --- a/x/ccv/provider/keeper/key_assignment.go +++ b/x/ccv/provider/keeper/key_assignment.go @@ -377,6 +377,14 @@ func (k Keeper) AssignConsumerKey( validator stakingtypes.Validator, consumerKey tmprotocrypto.PublicKey, ) error { + // check that the consumer chain is either registered or that + // ConsumerAdditionProposal was voted on. + if !k.CheckIfConsumerIsProposedOrRegistered(ctx, chainID) { + return errorsmod.Wrapf( + types.ErrUnknownConsumerChainId, chainID, + ) + } + consAddrTmp, err := ccvtypes.TMCryptoPublicKeyToConsAddr(consumerKey) if err != nil { return err @@ -629,3 +637,47 @@ func (k Keeper) DeleteKeyAssignments(ctx sdk.Context, chainID string) { k.DeleteConsumerAddrsToPrune(ctx, chainID, consumerAddrsToPrune.VscId) } } + +// CheckIfConsumerIsProposedOrRegistered checks if a consumer chain is either registered, meaning either already running +// or will run soon, or proposed its ConsumerAdditionProposal was submitted but the chain was not yet added to ICS yet. +func (k Keeper) CheckIfConsumerIsProposedOrRegistered(ctx sdk.Context, chainID string) bool { + allConsumerChains := k.GetAllRegisteredAndProposedChainIDs(ctx) + for _, c := range allConsumerChains { + if c == chainID { + return true + } + } + + return false +} + +// ValidatorConsensusKeyInUse checks if the given consensus key is already +// used by validator in a consumer chain. +// Note that this method is called when a new validator is created in the x/staking module of cosmos-sdk. +// In case it panics, the TX aborts and thus, the validator is not created. See AfterValidatorCreated hook. +func (k Keeper) ValidatorConsensusKeyInUse(ctx sdk.Context, valAddr sdk.ValAddress) bool { + // Get the validator being added in the staking module. + val, found := k.stakingKeeper.GetValidator(ctx, valAddr) + if !found { + // Abort TX, do NOT allow validator to be created + panic("did not find newly created validator in staking module") + } + + // Get the consensus address of the validator being added + consensusAddr, err := val.GetConsAddr() + if err != nil { + // Abort TX, do NOT allow validator to be created + panic("could not get validator cons addr ") + } + + allConsumerChains := k.GetAllRegisteredAndProposedChainIDs(ctx) + for _, c := range allConsumerChains { + if _, exist := k.GetValidatorByConsumerAddr(ctx, + c, + types.NewConsumerConsAddress(consensusAddr), + ); exist { + return true + } + } + return false +} diff --git a/x/ccv/provider/keeper/key_assignment_test.go b/x/ccv/provider/keeper/key_assignment_test.go index e9cb1dd646..b9a20e6b1b 100644 --- a/x/ccv/provider/keeper/key_assignment_test.go +++ b/x/ccv/provider/keeper/key_assignment_test.go @@ -389,17 +389,32 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { doActions func(sdk.Context, providerkeeper.Keeper) }{ /* - 0. Consumer registered: Assign PK0->CK0 and retrieve PK0->CK0 - 1. Consumer registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1 - 2. Consumer registered: Assign PK0->CK0, PK1->CK0 and error - 3. Consumer registered: Assign PK1->PK0 and error - 4. Consumer not registered: Assign PK0->CK0 and retrieve PK0->CK0 - 5. Consumer not registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1 - 6. Consumer not registered: Assign PK0->CK0, PK1->CK0 and error - 7. Consumer not registered: Assign PK1->PK0 and error + 0. Consumer not registered: Assign PK0->CK0 and error + 1. Consumer registered: Assign PK0->CK0 and retrieve PK0->CK0 + 2. Consumer registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1 + 3. Consumer registered: Assign PK0->CK0, PK1->CK0 and error + 4. Consumer registered: Assign PK1->PK0 and error + 5. Consumer proposed: Assign Assign PK0->CK0 and retrieve PK0->CK0 + 6. Consumer proposed: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1 + 7. Consumer proposed: Assign PK0->CK0, PK1->CK0 and error + 8. Consumer proposed: Assign PK1->PK0 and error */ { - name: "0", + name: "0", + mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {}, + doActions: func(ctx sdk.Context, k providerkeeper.Keeper) { + err := k.AssignConsumerKey(ctx, chainID, + providerIdentities[0].SDKStakingValidator(), + consumerIdentities[0].TMProtoCryptoPublicKey(), + ) + require.Error(t, err) + _, found := k.GetValidatorByConsumerAddr(ctx, chainID, + consumerIdentities[0].ConsumerConsAddress()) + require.False(t, found) + }, + }, + { + name: "1", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -424,7 +439,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "1", + name: "2", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -460,7 +475,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "2", + name: "3", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -493,7 +508,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "3", + name: "4", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -511,7 +526,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "4", + name: "5", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -520,6 +535,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { ) }, doActions: func(ctx sdk.Context, k providerkeeper.Keeper) { + k.SetProposedConsumerChain(ctx, chainID, 0) err := k.AssignConsumerKey(ctx, chainID, providerIdentities[0].SDKStakingValidator(), consumerIdentities[0].TMProtoCryptoPublicKey(), @@ -532,7 +548,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "5", + name: "6", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -544,6 +560,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { ) }, doActions: func(ctx sdk.Context, k providerkeeper.Keeper) { + k.SetProposedConsumerChain(ctx, chainID, 0) err := k.AssignConsumerKey(ctx, chainID, providerIdentities[0].SDKStakingValidator(), consumerIdentities[0].TMProtoCryptoPublicKey(), @@ -561,7 +578,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "6", + name: "7", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -573,6 +590,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { ) }, doActions: func(ctx sdk.Context, k providerkeeper.Keeper) { + k.SetProposedConsumerChain(ctx, chainID, 0) err := k.AssignConsumerKey(ctx, chainID, providerIdentities[0].SDKStakingValidator(), consumerIdentities[0].TMProtoCryptoPublicKey(), @@ -590,7 +608,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { }, }, { - name: "7", + name: "8", mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { gomock.InOrder( mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx, @@ -599,6 +617,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) { ) }, doActions: func(ctx sdk.Context, k providerkeeper.Keeper) { + k.SetProposedConsumerChain(ctx, chainID, 0) err := k.AssignConsumerKey(ctx, chainID, providerIdentities[1].SDKStakingValidator(), providerIdentities[0].TMProtoCryptoPublicKey(),