diff --git a/testutil/keeper/expectations.go b/testutil/keeper/expectations.go index 07d4d320ba..7814fe0fcf 100644 --- a/testutil/keeper/expectations.go +++ b/testutil/keeper/expectations.go @@ -81,8 +81,9 @@ func GetMocksForSetConsumerChain(ctx sdk.Context, mocks *MockedKeepers, } } -// GetMocksForStopConsumerChain returns mock expectations needed to call StopConsumerChain(). -func GetMocksForStopConsumerChain(ctx sdk.Context, mocks *MockedKeepers) []*gomock.Call { +// GetMocksForStopConsumerChainWithCloseChannel returns mock expectations needed to call StopConsumerChain() when +// `closeChan` is true. +func GetMocksForStopConsumerChainWithCloseChannel(ctx sdk.Context, mocks *MockedKeepers) []*gomock.Call { dummyCap := &capabilitytypes.Capability{} return []*gomock.Call{ mocks.MockChannelKeeper.EXPECT().GetChannel(gomock.Any(), types.ProviderPortID, "channelID").Return( diff --git a/testutil/keeper/unit_test_helpers.go b/testutil/keeper/unit_test_helpers.go index c5540f4601..4ccb8a1861 100644 --- a/testutil/keeper/unit_test_helpers.go +++ b/testutil/keeper/unit_test_helpers.go @@ -207,7 +207,10 @@ func GetNewVSCMaturedPacketData() types.VSCMaturedPacketData { } // SetupForStoppingConsumerChain registers expected mock calls and corresponding state setup -// which asserts that a consumer chain was properly stopped from StopConsumerChain(). +// which assert that a consumer chain was properly setup to be later stopped from `StopConsumerChain`. +// Note: This function only setups and tests that we correctly setup a consumer chain that we could later stop when +// calling `StopConsumerChain` -- this does NOT necessarily mean that the consumer chain is stopped. +// Also see `TestProviderStateIsCleanedAfterConsumerChainIsStopped`. func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks MockedKeepers, ) { @@ -215,7 +218,6 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, expectations := GetMocksForCreateConsumerClient(ctx, &mocks, "chainID", clienttypes.NewHeight(4, 5)) expectations = append(expectations, GetMocksForSetConsumerChain(ctx, &mocks, "chainID")...) - expectations = append(expectations, GetMocksForStopConsumerChain(ctx, &mocks)...) gomock.InOrder(expectations...) @@ -226,6 +228,43 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context, require.NoError(t, err) } +// TestProviderStateIsCleanedAfterConsumerChainIsStopped executes test assertions for the provider's state being cleaned +// after a stopped consumer chain. +func TestProviderStateIsCleanedAfterConsumerChainIsStopped(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper, + expectedChainID, expectedChannelID string, +) { + t.Helper() + _, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID) + require.False(t, found) + _, found = providerKeeper.GetChainToChannel(ctx, expectedChainID) + require.False(t, found) + _, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID) + require.False(t, found) + _, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID) + require.False(t, found) + acks := providerKeeper.GetSlashAcks(ctx, expectedChainID) + require.Empty(t, acks) + _, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID) + require.False(t, found) + + require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID)) + + // test key assignment state is cleaned + require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID)) + require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID)) + require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID)) + require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID)) + + allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx) + for _, entry := range allGlobalEntries { + require.NotEqual(t, expectedChainID, entry.ConsumerChainID) + } + + slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID) + require.Empty(t, slashPacketData) + require.Empty(t, vscMaturedPacketData) +} + func GetTestConsumerAdditionProp() *providertypes.ConsumerAdditionProposal { prop := providertypes.NewConsumerAdditionProposal( "chainID", diff --git a/x/ccv/provider/keeper/proposal_test.go b/x/ccv/provider/keeper/proposal_test.go index c84369e815..0d7085068a 100644 --- a/x/ccv/provider/keeper/proposal_test.go +++ b/x/ccv/provider/keeper/proposal_test.go @@ -473,6 +473,9 @@ func TestHandleConsumerRemovalProposal(t *testing.T) { // meaning no external keeper methods are allowed to be called. if tc.expAppendProp { testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + + // assert mocks for expected calls to `StopConsumerChain` when closing the underlying channel + gomock.InOrder(testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) } tc.setupMocks(ctx, providerKeeper, tc.prop.ChainId) @@ -528,6 +531,9 @@ func TestStopConsumerChain(t *testing.T) { setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks) + // assert mocks for expected calls to `StopConsumerChain` when closing the underlying channel + gomock.InOrder(testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) + providerKeeper.QueueGlobalSlashEntry(ctx, providertypes.NewGlobalSlashEntry( ctx.BlockTime(), "chainID", 1, cryptoutil.NewCryptoIdentityFromIntSeed(90).ProviderConsAddress())) @@ -546,6 +552,9 @@ func TestStopConsumerChain(t *testing.T) { description: "valid stop of consumer chain, all mock calls hit", setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) { testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks) + + // assert mocks for expected calls to `StopConsumerChain` when closing the underlying channel + gomock.InOrder(testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) }, expErr: false, }, @@ -569,48 +578,12 @@ func TestStopConsumerChain(t *testing.T) { require.NoError(t, err) } - testProviderStateIsCleaned(t, ctx, providerKeeper, "chainID", "channelID") + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") ctrl.Finish() } } -// testProviderStateIsCleaned executes test assertions for the proposer's state being cleaned after a stopped consumer chain. -func testProviderStateIsCleaned(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper, - expectedChainID, expectedChannelID string, -) { - t.Helper() - _, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID) - require.False(t, found) - _, found = providerKeeper.GetChainToChannel(ctx, expectedChainID) - require.False(t, found) - _, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID) - require.False(t, found) - _, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID) - require.False(t, found) - acks := providerKeeper.GetSlashAcks(ctx, expectedChainID) - require.Empty(t, acks) - _, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID) - require.False(t, found) - - require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID)) - - // test key assignment state is cleaned - require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID)) - require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID)) - require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID)) - require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID)) - - allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx) - for _, entry := range allGlobalEntries { - require.NotEqual(t, expectedChainID, entry.ConsumerChainID) - } - - slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID) - require.Empty(t, slashPacketData) - require.Empty(t, vscMaturedPacketData) -} - // TestPendingConsumerRemovalPropDeletion tests the getting/setting // and deletion methods for pending consumer removal props func TestPendingConsumerRemovalPropDeletion(t *testing.T) { @@ -1075,8 +1048,8 @@ func TestBeginBlockCCR(t *testing.T) { expectations = append(expectations, testkeeper.GetMocksForSetConsumerChain(ctx, &mocks, prop.ChainId)...) } // Only first two consumer chains should be stopped - expectations = append(expectations, testkeeper.GetMocksForStopConsumerChain(ctx, &mocks)...) - expectations = append(expectations, testkeeper.GetMocksForStopConsumerChain(ctx, &mocks)...) + expectations = append(expectations, testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) + expectations = append(expectations, testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) gomock.InOrder(expectations...) diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go index b266211a42..590286261c 100644 --- a/x/ccv/provider/keeper/relay_test.go +++ b/x/ccv/provider/keeper/relay_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "strings" "testing" "time" @@ -699,7 +700,7 @@ func TestSendVSCPacketsToChainFailure(t *testing.T) { ) // Append mocks for expected call to StopConsumerChain - mockCalls = append(mockCalls, testkeeper.GetMocksForStopConsumerChain(ctx, &mocks)...) + mockCalls = append(mockCalls, testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) // Assert mock calls hit gomock.InOrder(mockCalls...) @@ -715,3 +716,72 @@ func TestSendVSCPacketsToChainFailure(t *testing.T) { // Pending VSC packets should be deleted in StopConsumerChain require.Empty(t, providerKeeper.GetPendingVSCPackets(ctx, "consumerChainID")) } + +// TestOnTimeoutPacketWithNoChainFound tests the `OnTimeoutPacket` method fails when no chain is found +func TestOnTimeoutPacketWithNoChainFound(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + // We do not `SetChannelToChain` for "channelID" and therefore `OnTimeoutPacket` fails + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + err := providerKeeper.OnTimeoutPacket(ctx, packet) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), channeltypes.ErrInvalidChannel.Error())) +} + +// TestOnTimeoutPacketStopsChain tests that the chain is stopped in case of a timeout +func TestOnTimeoutPacketStopsChain(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + providerKeeper.SetParams(ctx, providertypes.DefaultParams()) + + testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + err := providerKeeper.OnTimeoutPacket(ctx, packet) + + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") + require.NoError(t, err) +} + +// TestOnAcknowledgementPacketWithNoAckError tests `OnAcknowledgementPacket` when the underlying ack contains no error +func TestOnAcknowledgementPacketWithNoAckError(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + ack := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Result{Result: []byte{}}} + err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ack) + require.NoError(t, err) +} + +// TestOnAcknowledgementPacketWithAckError tests `OnAcknowledgementPacket` when the underlying ack contains an error +func TestOnAcknowledgementPacketWithAckError(t *testing.T) { + // Keeper setup + providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + providerKeeper.SetParams(ctx, providertypes.DefaultParams()) + + // test that `OnAcknowledgementPacket` returns an error if the ack contains an error and the channel is unknown + ackError := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Error{Error: "some error"}} + err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ackError) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), providertypes.ErrUnknownConsumerChannelId.Error())) + + // test that we stop the consumer chain when `OnAcknowledgementPacket` returns an error and the chain is found + testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + packet := channeltypes.Packet{ + SourceChannel: "channelID", + } + + err = providerKeeper.OnAcknowledgementPacket(ctx, packet, ackError) + + testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID") + require.NoError(t, err) +} diff --git a/x/ccv/provider/proposal_handler_test.go b/x/ccv/provider/proposal_handler_test.go index d3707d8c28..e8963421d2 100644 --- a/x/ccv/provider/proposal_handler_test.go +++ b/x/ccv/provider/proposal_handler_test.go @@ -108,6 +108,9 @@ func TestProviderProposalHandler(t *testing.T) { case tc.expValidConsumerRemoval: testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks) + // assert mocks for expected calls to `StopConsumerChain` when closing the underlying channel + gomock.InOrder(testkeeper.GetMocksForStopConsumerChainWithCloseChannel(ctx, &mocks)...) + case tc.expValidEquivocation: providerKeeper.SetSlashLog(ctx, providertypes.NewProviderConsAddress(equivocation.GetConsensusAddress())) mocks.MockEvidenceKeeper.EXPECT().HandleEquivocationEvidence(ctx, equivocation)