diff --git a/testutil/keeper/unit_test_helpers.go b/testutil/keeper/unit_test_helpers.go index c5540f4601..c98275cffc 100644 --- a/testutil/keeper/unit_test_helpers.go +++ b/testutil/keeper/unit_test_helpers.go @@ -207,7 +207,10 @@ func GetNewVSCMaturedPacketData() types.VSCMaturedPacketData { } // SetupForStoppingConsumerChain registers expected mock calls and corresponding state setup -// which asserts that a consumer chain was properly stopped from StopConsumerChain(). +// which assert that a consumer chain was properly setup to be later stopped from `StopConsumerChain`. +// Note: This function only setups and tests that we correctly setup a consumer chain that we could later stop when +// calling `StopConsumerChain` -- this does NOT necessarily mean that the consumer chain is stopped. +// Also see `SetupForStoppingConsumerChainChannel` and `TestProviderStateIsCleanedAfterConsumerChainIsStopped`. func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks MockedKeepers, ) { @@ -215,7 +218,6 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, expectations := GetMocksForCreateConsumerClient(ctx, &mocks, "chainID", clienttypes.NewHeight(4, 5)) expectations = append(expectations, GetMocksForSetConsumerChain(ctx, &mocks, "chainID")...) - expectations = append(expectations, GetMocksForStopConsumerChain(ctx, &mocks)...) gomock.InOrder(expectations...) @@ -226,6 +228,50 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, require.NoError(t, err) } +// SetupForStoppingConsumerChainChannel registers expected mock calls which assert that the channel to the consumer +// chain is closed. To be used when we test that `StopConsumerChain` is called with `closeChan` set to `true`. +func SetupForStoppingConsumerChainChannel(t *testing.T, ctx sdk.Context, mocks MockedKeepers) { + t.Helper() + gomock.InOrder(GetMocksForStopConsumerChain(ctx, &mocks)...) +} + +// TestProviderStateIsCleanedAfterConsumerChainIsStopped executes test assertions for the provider's state being cleaned +// after a stopped consumer chain. +func TestProviderStateIsCleanedAfterConsumerChainIsStopped(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper, + expectedChainID, expectedChannelID string, +) { + t.Helper() + _, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID) + require.False(t, found) + _, found = providerKeeper.GetChainToChannel(ctx, expectedChainID) + require.False(t, found) + _, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID) + require.False(t, found) + _, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID) + require.False(t, found) + acks := providerKeeper.GetSlashAcks(ctx, expectedChainID) + require.Empty(t, acks) + _, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID) + require.False(t, found) + + require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID)) + + // test key assignment state is cleaned + require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID)) + require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID)) + require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID)) + require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID)) + + allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx) + for _, entry := range allGlobalEntries { + require.NotEqual(t, expectedChainID, entry.ConsumerChainID) + } + + slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID) + require.Empty(t, slashPacketData) + require.Empty(t, vscMaturedPacketData) +} + func GetTestConsumerAdditionProp() *providertypes.ConsumerAdditionProposal { prop := providertypes.NewConsumerAdditionProposal( "chainID", diff --git a/x/ccv/provider/keeper/proposal_test.go b/x/ccv/provider/keeper/proposal_test.go index c84369e815..342e22078c 100644 --- a/x/ccv/provider/keeper/proposal_test.go +++ b/x/ccv/provider/keeper/proposal_test.go @@ -473,6 +473,7 @@ func TestHandleConsumerRemovalProposal(t *testing.T) { // meaning no external keeper methods are allowed to be called. if tc.expAppendProp { testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks) } tc.setupMocks(ctx, providerKeeper, tc.prop.ChainId) @@ -527,6 +528,7 @@ func TestStopConsumerChain(t *testing.T) { description: "valid stop of consumer chain, throttle related queues are cleaned", setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks) + testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks) providerKeeper.QueueGlobalSlashEntry(ctx, providertypes.NewGlobalSlashEntry( ctx.BlockTime(), "chainID", 1, cryptoutil.NewCryptoIdentityFromIntSeed(90).ProviderConsAddress())) @@ -546,6 +548,7 @@ func TestStopConsumerChain(t *testing.T) { description: "valid stop of consumer chain, all mock calls hit", setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks) + testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks) }, expErr: false, }, @@ -569,48 +572,12 @@ func TestStopConsumerChain(t *testing.T) { require.NoError(t, err) } - testProviderStateIsCleaned(t, ctx, providerKeeper, "chainID", "channelID") + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") ctrl.Finish() } } -// testProviderStateIsCleaned executes test assertions for the proposer's state being cleaned after a stopped consumer chain. -func testProviderStateIsCleaned(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper, - expectedChainID, expectedChannelID string, -) { - t.Helper() - _, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID) - require.False(t, found) - _, found = providerKeeper.GetChainToChannel(ctx, expectedChainID) - require.False(t, found) - _, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID) - require.False(t, found) - _, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID) - require.False(t, found) - acks := providerKeeper.GetSlashAcks(ctx, expectedChainID) - require.Empty(t, acks) - _, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID) - require.False(t, found) - - require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID)) - - // test key assignment state is cleaned - require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID)) - require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID)) - require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID)) - require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID)) - - allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx) - for _, entry := range allGlobalEntries { - require.NotEqual(t, expectedChainID, entry.ConsumerChainID) - } - - slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID) - require.Empty(t, slashPacketData) - require.Empty(t, vscMaturedPacketData) -} - // TestPendingConsumerRemovalPropDeletion tests the getting/setting // and deletion methods for pending consumer removal props func TestPendingConsumerRemovalPropDeletion(t *testing.T) { diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go index b266211a42..798238c8f9 100644 --- a/x/ccv/provider/keeper/relay_test.go +++ b/x/ccv/provider/keeper/relay_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "strings" "testing" "time" @@ -715,3 +716,72 @@ func TestSendVSCPacketsToChainFailure(t *testing.T) { // Pending VSC packets should be deleted in StopConsumerChain require.Empty(t, providerKeeper.GetPendingVSCPackets(ctx, "consumerChainID")) } + +// TestOnTimeoutPacketWithNoChainFound tests the `OnTimeoutPacket` method fails when no chain is found +func TestOnTimeoutPacketWithNoChainFound(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + // We do not `SetChannelToChain` for "channelID" and therefore `OnTimeoutPacket` fails + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + err := providerKeeper.OnTimeoutPacket(ctx, packet) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), channeltypes.ErrInvalidChannel.Error())) +} + +// TestOnTimeoutPacketStopsChain tests that the chain is stopped in case of a timeout +func TestOnTimeoutPacketStopsChain(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + providerKeeper.SetParams(ctx, providertypes.DefaultParams()) + + testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + err := providerKeeper.OnTimeoutPacket(ctx, packet) + + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") + require.NoError(t, err) +} + +// TestOnAcknowledgementPacketWithNoAckError tests `OnAcknowledgementPacket` when the underlying ack contains no error +func TestOnAcknowledgementPacketWithNoAckError(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + ack := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Result{Result: []byte{}}} + err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ack) + require.NoError(t, err) +} + +// TestOnAcknowledgementPacketWithAckError tests `OnAcknowledgementPacket` when the underlying ack contains an error +func TestOnAcknowledgementPacketWithAckError(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + providerKeeper.SetParams(ctx, providertypes.DefaultParams()) + + // test that `OnAcknowledgementPacket` returns an error if the ack contains an error and the channel is unknown + ackError := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Error{Error: "some error"}} + err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ackError) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), providertypes.ErrUnknownConsumerChannelId.Error())) + + // test that we stop the consumer chain when `OnAcknowledgementPacket` returns an error and the chain is found + testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + + err = providerKeeper.OnAcknowledgementPacket(ctx, packet, ackError) + + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") + require.NoError(t, err) +} diff --git a/x/ccv/provider/proposal_handler_test.go b/x/ccv/provider/proposal_handler_test.go index d3707d8c28..6d31c61638 100644 --- a/x/ccv/provider/proposal_handler_test.go +++ b/x/ccv/provider/proposal_handler_test.go @@ -107,6 +107,7 @@ func TestProviderProposalHandler(t *testing.T) { case tc.expValidConsumerRemoval: testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks) case tc.expValidEquivocation: providerKeeper.SetSlashLog(ctx, providertypes.NewProviderConsAddress(equivocation.GetConsensusAddress()))