Skip to content

Commit

Permalink
fix #1282
Browse files Browse the repository at this point in the history
  • Loading branch information
sainoe committed Nov 22, 2023
1 parent 50cb42a commit 441bcee
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 60 deletions.
44 changes: 1 addition & 43 deletions x/ccv/provider/keeper/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,50 +91,8 @@ func (h Hooks) AfterUnbondingInitiated(ctx sdk.Context, id uint64) error {
return nil
}

// ValidatorConsensusKeyInUse is called when a new validator is created
// in the x/staking module of cosmos-sdk. In case it panics, the TX aborts
// and thus, the validator is not created. See AfterValidatorCreated hook.
func ValidatorConsensusKeyInUse(k *Keeper, ctx sdk.Context, valAddr sdk.ValAddress) bool {
// Get the validator being added in the staking module.
val, found := k.stakingKeeper.GetValidator(ctx, valAddr)
if !found {
// Abort TX, do NOT allow validator to be created
panic("did not find newly created validator in staking module")
}

// Get the consensus address of the validator being added
consensusAddr, err := val.GetConsAddr()
if err != nil {
// Abort TX, do NOT allow validator to be created
panic("could not get validator cons addr ")
}

allConsumerChains := []string{}
consumerChains := k.GetAllConsumerChains(ctx)
for _, consumerChain := range consumerChains {
allConsumerChains = append(allConsumerChains, consumerChain.ChainId)
}
proposedChains := k.GetAllProposedConsumerChainIDs(ctx)
for _, proposedChain := range proposedChains {
allConsumerChains = append(allConsumerChains, proposedChain.ChainID)
}
pendingChainIDs := k.GetAllPendingConsumerChainIDs(ctx)
allConsumerChains = append(allConsumerChains, pendingChainIDs...)

for _, c := range allConsumerChains {
if _, exist := k.GetValidatorByConsumerAddr(
ctx,
c,
providertypes.NewConsumerConsAddress(consensusAddr)); exist {
return true
}
}

return false
}

func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error {
if ValidatorConsensusKeyInUse(h.k, ctx, valAddr) {
if h.k.ValidatorConsensusKeyInUse(ctx, valAddr) {
// Abort TX, do NOT allow validator to be created
panic("cannot create a validator with a consensus key that is already in use or was recently in use as an assigned consumer chain key")
}
Expand Down
2 changes: 1 addition & 1 deletion x/ccv/provider/keeper/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestValidatorConsensusKeyInUse(t *testing.T) {
tt.setup(ctx, k)

t.Run(tt.name, func(t *testing.T) {
if actual := providerkeeper.ValidatorConsensusKeyInUse(&k, ctx, newValidator.SDKStakingValidator().GetOperator()); actual != tt.expect {
if actual := k.ValidatorConsensusKeyInUse(ctx, newValidator.SDKStakingValidator().GetOperator()); actual != tt.expect {
t.Errorf("validatorConsensusKeyInUse() = %v, want %v", actual, tt.expect)
}
})
Expand Down
16 changes: 16 additions & 0 deletions x/ccv/provider/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,3 +1123,19 @@ func (k Keeper) GetSlashLog(
func (k Keeper) BondDenom(ctx sdk.Context) string {
return k.stakingKeeper.BondDenom(ctx)
}

func (k Keeper) GetAllRegisteredAndProposedChainIDs(ctx sdk.Context) []string {
allConsumerChains := []string{}
consumerChains := k.GetAllConsumerChains(ctx)
for _, consumerChain := range consumerChains {
allConsumerChains = append(allConsumerChains, consumerChain.ChainId)
}
proposedChains := k.GetAllProposedConsumerChainIDs(ctx)
for _, proposedChain := range proposedChains {
allConsumerChains = append(allConsumerChains, proposedChain.ChainID)
}
pendingChainIDs := k.GetAllPendingConsumerChainIDs(ctx)
allConsumerChains = append(allConsumerChains, pendingChainIDs...)

return allConsumerChains
}
52 changes: 52 additions & 0 deletions x/ccv/provider/keeper/key_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,14 @@ func (k Keeper) AssignConsumerKey(
validator stakingtypes.Validator,
consumerKey tmprotocrypto.PublicKey,
) error {
// check that the consumer chain is either registered or that
// ConsumerAdditionProposal was voted on.
if !k.CheckIfConsumerIsProposedOrRegistered(ctx, chainID) {
return errorsmod.Wrapf(
types.ErrUnknownConsumerChainId, chainID,
)
}

consAddrTmp, err := ccvtypes.TMCryptoPublicKeyToConsAddr(consumerKey)
if err != nil {
return err
Expand Down Expand Up @@ -629,3 +637,47 @@ func (k Keeper) DeleteKeyAssignments(ctx sdk.Context, chainID string) {
k.DeleteConsumerAddrsToPrune(ctx, chainID, consumerAddrsToPrune.VscId)
}
}

// CheckIfConsumerIsProposedOrRegistered checks if a consumer chain is either registered, meaning either already running
// or will run soon, or proposed its ConsumerAdditionProposal was submitted but the chain was not yet added to ICS yet.
func (k Keeper) CheckIfConsumerIsProposedOrRegistered(ctx sdk.Context, chainID string) bool {
allConsumerChains := k.GetAllRegisteredAndProposedChainIDs(ctx)
for _, c := range allConsumerChains {
if c == chainID {
return true
}
}

return false
}

// ValidatorConsensusKeyInUse checks if the given consensus key is already
// used by validator in a consumer chain.
// Note that this method is called when a new validator is created in the x/staking module of cosmos-sdk.
// In case it panics, the TX aborts and thus, the validator is not created. See AfterValidatorCreated hook.
func (k Keeper) ValidatorConsensusKeyInUse(ctx sdk.Context, valAddr sdk.ValAddress) bool {
// Get the validator being added in the staking module.
val, found := k.stakingKeeper.GetValidator(ctx, valAddr)
if !found {
// Abort TX, do NOT allow validator to be created
panic("did not find newly created validator in staking module")
}

// Get the consensus address of the validator being added
consensusAddr, err := val.GetConsAddr()
if err != nil {
// Abort TX, do NOT allow validator to be created
panic("could not get validator cons addr ")
}

allConsumerChains := k.GetAllRegisteredAndProposedChainIDs(ctx)
for _, c := range allConsumerChains {
if _, exist := k.GetValidatorByConsumerAddr(ctx,
c,
types.NewConsumerConsAddress(consensusAddr),
); exist {
return true
}
}
return false
}
51 changes: 35 additions & 16 deletions x/ccv/provider/keeper/key_assignment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,17 +389,32 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
doActions func(sdk.Context, providerkeeper.Keeper)
}{
/*
0. Consumer registered: Assign PK0->CK0 and retrieve PK0->CK0
1. Consumer registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1
2. Consumer registered: Assign PK0->CK0, PK1->CK0 and error
3. Consumer registered: Assign PK1->PK0 and error
4. Consumer not registered: Assign PK0->CK0 and retrieve PK0->CK0
5. Consumer not registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1
6. Consumer not registered: Assign PK0->CK0, PK1->CK0 and error
7. Consumer not registered: Assign PK1->PK0 and error
0. Consumer not registered: Assign PK0->CK0 and error
1. Consumer registered: Assign PK0->CK0 and retrieve PK0->CK0
2. Consumer registered: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1
3. Consumer registered: Assign PK0->CK0, PK1->CK0 and error
4. Consumer registered: Assign PK1->PK0 and error
5. Consumer proposed: Assign Assign PK0->CK0 and retrieve PK0->CK0
6. Consumer proposed: Assign PK0->CK0, PK0->CK1 and retrieve PK0->CK1
7. Consumer proposed: Assign PK0->CK0, PK1->CK0 and error
8. Consumer proposed: Assign PK1->PK0 and error
*/
{
name: "0",
name: "0",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {},
doActions: func(ctx sdk.Context, k providerkeeper.Keeper) {
err := k.AssignConsumerKey(ctx, chainID,
providerIdentities[0].SDKStakingValidator(),
consumerIdentities[0].TMProtoCryptoPublicKey(),
)
require.Error(t, err)
_, found := k.GetValidatorByConsumerAddr(ctx, chainID,
consumerIdentities[0].ConsumerConsAddress())
require.False(t, found)
},
},
{
name: "1",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -424,7 +439,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "1",
name: "2",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand Down Expand Up @@ -460,7 +475,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "2",
name: "3",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand Down Expand Up @@ -493,7 +508,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "3",
name: "4",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -511,7 +526,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "4",
name: "5",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -520,6 +535,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
)
},
doActions: func(ctx sdk.Context, k providerkeeper.Keeper) {
k.SetProposedConsumerChain(ctx, chainID, 0)
err := k.AssignConsumerKey(ctx, chainID,
providerIdentities[0].SDKStakingValidator(),
consumerIdentities[0].TMProtoCryptoPublicKey(),
Expand All @@ -532,7 +548,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "5",
name: "6",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -544,6 +560,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
)
},
doActions: func(ctx sdk.Context, k providerkeeper.Keeper) {
k.SetProposedConsumerChain(ctx, chainID, 0)
err := k.AssignConsumerKey(ctx, chainID,
providerIdentities[0].SDKStakingValidator(),
consumerIdentities[0].TMProtoCryptoPublicKey(),
Expand All @@ -561,7 +578,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "6",
name: "7",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -573,6 +590,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
)
},
doActions: func(ctx sdk.Context, k providerkeeper.Keeper) {
k.SetProposedConsumerChain(ctx, chainID, 0)
err := k.AssignConsumerKey(ctx, chainID,
providerIdentities[0].SDKStakingValidator(),
consumerIdentities[0].TMProtoCryptoPublicKey(),
Expand All @@ -590,7 +608,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
},
},
{
name: "7",
name: "8",
mockSetup: func(ctx sdk.Context, k providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
gomock.InOrder(
mocks.MockStakingKeeper.EXPECT().GetValidatorByConsAddr(ctx,
Expand All @@ -599,6 +617,7 @@ func TestAssignConsensusKeyForConsumerChain(t *testing.T) {
)
},
doActions: func(ctx sdk.Context, k providerkeeper.Keeper) {
k.SetProposedConsumerChain(ctx, chainID, 0)
err := k.AssignConsumerKey(ctx, chainID,
providerIdentities[1].SDKStakingValidator(),
providerIdentities[0].TMProtoCryptoPublicKey(),
Expand Down

0 comments on commit 441bcee

Please sign in to comment.