Skip to content

Commit

Permalink
Fix test by veryfing provider's chain state is clean.
Browse files Browse the repository at this point in the history
  • Loading branch information
insumity committed Aug 18, 2023
1 parent cca008d commit 92e37d8
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 39 deletions.
50 changes: 48 additions & 2 deletions testutil/keeper/unit_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,17 @@ func GetNewVSCMaturedPacketData() types.VSCMaturedPacketData {
}

// SetupForStoppingConsumerChain registers expected mock calls and corresponding state setup
// which asserts that a consumer chain was properly stopped from StopConsumerChain().
// which assert that a consumer chain was properly setup to be later stopped from `StopConsumerChain`.
// Note: This function only setups and tests that we correctly setup a consumer chain that we could later stop when
// calling `StopConsumerChain` -- this does NOT necessarily mean that the consumer chain is stopped.
// Also see `SetupForStoppingConsumerChainChannel` and `TestProviderStateIsCleanedAfterConsumerChainIsStopped`.
func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context,
providerKeeper *providerkeeper.Keeper, mocks MockedKeepers,
) {
t.Helper()
expectations := GetMocksForCreateConsumerClient(ctx, &mocks,
"chainID", clienttypes.NewHeight(4, 5))
expectations = append(expectations, GetMocksForSetConsumerChain(ctx, &mocks, "chainID")...)
expectations = append(expectations, GetMocksForStopConsumerChain(ctx, &mocks)...)

gomock.InOrder(expectations...)

Expand All @@ -226,6 +228,50 @@ func SetupForStoppingConsumerChain(t *testing.T, ctx sdk.Context,
require.NoError(t, err)
}

// SetupForStoppingConsumerChainChannel registers expected mock calls which assert that the channel to the consumer
// chain is closed. To be used when we test that `StopConsumerChain` is called with `closeChan` set to `true`.
func SetupForStoppingConsumerChainChannel(t *testing.T, ctx sdk.Context, mocks MockedKeepers) {
t.Helper()
gomock.InOrder(GetMocksForStopConsumerChain(ctx, &mocks)...)
}

// TestProviderStateIsCleanedAfterConsumerChainIsStopped executes test assertions for the provider's state being cleaned
// after a stopped consumer chain.
func TestProviderStateIsCleanedAfterConsumerChainIsStopped(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper,
expectedChainID, expectedChannelID string,
) {
t.Helper()
_, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID)
require.False(t, found)
_, found = providerKeeper.GetChainToChannel(ctx, expectedChainID)
require.False(t, found)
_, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID)
require.False(t, found)
_, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID)
require.False(t, found)
acks := providerKeeper.GetSlashAcks(ctx, expectedChainID)
require.Empty(t, acks)
_, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID)
require.False(t, found)

require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID))

// test key assignment state is cleaned
require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID))
require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID))
require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID))
require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID))

allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx)
for _, entry := range allGlobalEntries {
require.NotEqual(t, expectedChainID, entry.ConsumerChainID)
}

slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID)
require.Empty(t, slashPacketData)
require.Empty(t, vscMaturedPacketData)
}

func GetTestConsumerAdditionProp() *providertypes.ConsumerAdditionProposal {
prop := providertypes.NewConsumerAdditionProposal(
"chainID",
Expand Down
41 changes: 4 additions & 37 deletions x/ccv/provider/keeper/proposal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ func TestHandleConsumerRemovalProposal(t *testing.T) {
// meaning no external keeper methods are allowed to be called.
if tc.expAppendProp {
testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks)
testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks)
}

tc.setupMocks(ctx, providerKeeper, tc.prop.ChainId)
Expand Down Expand Up @@ -527,6 +528,7 @@ func TestStopConsumerChain(t *testing.T) {
description: "valid stop of consumer chain, throttle related queues are cleaned",
setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks)
testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks)

providerKeeper.QueueGlobalSlashEntry(ctx, providertypes.NewGlobalSlashEntry(
ctx.BlockTime(), "chainID", 1, cryptoutil.NewCryptoIdentityFromIntSeed(90).ProviderConsAddress()))
Expand All @@ -546,6 +548,7 @@ func TestStopConsumerChain(t *testing.T) {
description: "valid stop of consumer chain, all mock calls hit",
setup: func(ctx sdk.Context, providerKeeper *providerkeeper.Keeper, mocks testkeeper.MockedKeepers) {
testkeeper.SetupForStoppingConsumerChain(t, ctx, providerKeeper, mocks)
testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks)
},
expErr: false,
},
Expand All @@ -569,48 +572,12 @@ func TestStopConsumerChain(t *testing.T) {
require.NoError(t, err)
}

testProviderStateIsCleaned(t, ctx, providerKeeper, "chainID", "channelID")
testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID")

ctrl.Finish()
}
}

// testProviderStateIsCleaned executes test assertions for the proposer's state being cleaned after a stopped consumer chain.
func testProviderStateIsCleaned(t *testing.T, ctx sdk.Context, providerKeeper providerkeeper.Keeper,
expectedChainID, expectedChannelID string,
) {
t.Helper()
_, found := providerKeeper.GetConsumerClientId(ctx, expectedChainID)
require.False(t, found)
_, found = providerKeeper.GetChainToChannel(ctx, expectedChainID)
require.False(t, found)
_, found = providerKeeper.GetChannelToChain(ctx, expectedChannelID)
require.False(t, found)
_, found = providerKeeper.GetInitChainHeight(ctx, expectedChainID)
require.False(t, found)
acks := providerKeeper.GetSlashAcks(ctx, expectedChainID)
require.Empty(t, acks)
_, found = providerKeeper.GetInitTimeoutTimestamp(ctx, expectedChainID)
require.False(t, found)

require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, expectedChainID))

// test key assignment state is cleaned
require.Empty(t, providerKeeper.GetAllValidatorConsumerPubKeys(ctx, &expectedChainID))
require.Empty(t, providerKeeper.GetAllValidatorsByConsumerAddr(ctx, &expectedChainID))
require.Empty(t, providerKeeper.GetAllKeyAssignmentReplacements(ctx, expectedChainID))
require.Empty(t, providerKeeper.GetAllConsumerAddrsToPrune(ctx, expectedChainID))

allGlobalEntries := providerKeeper.GetAllGlobalSlashEntries(ctx)
for _, entry := range allGlobalEntries {
require.NotEqual(t, expectedChainID, entry.ConsumerChainID)
}

slashPacketData, vscMaturedPacketData, _, _ := providerKeeper.GetAllThrottledPacketData(ctx, expectedChainID)
require.Empty(t, slashPacketData)
require.Empty(t, vscMaturedPacketData)
}

// TestPendingConsumerRemovalPropDeletion tests the getting/setting
// and deletion methods for pending consumer removal props
func TestPendingConsumerRemovalPropDeletion(t *testing.T) {
Expand Down
70 changes: 70 additions & 0 deletions x/ccv/provider/keeper/relay_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper_test

import (
"strings"
"testing"
"time"

Expand Down Expand Up @@ -715,3 +716,72 @@ func TestSendVSCPacketsToChainFailure(t *testing.T) {
// Pending VSC packets should be deleted in StopConsumerChain
require.Empty(t, providerKeeper.GetPendingVSCPackets(ctx, "consumerChainID"))
}

// TestOnTimeoutPacketWithNoChainFound tests the `OnTimeoutPacket` method fails when no chain is found
func TestOnTimeoutPacketWithNoChainFound(t *testing.T) {
// Keeper setup
providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()

// We do not `SetChannelToChain` for "channelID" and therefore `OnTimeoutPacket` fails
packet := channeltypes.Packet{
SourceChannel: "channelID",
}
err := providerKeeper.OnTimeoutPacket(ctx, packet)
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), channeltypes.ErrInvalidChannel.Error()))
}

// TestOnTimeoutPacketStopsChain tests that the chain is stopped in case of a timeout
func TestOnTimeoutPacketStopsChain(t *testing.T) {
// Keeper setup
providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()
providerKeeper.SetParams(ctx, providertypes.DefaultParams())

testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks)

packet := channeltypes.Packet{
SourceChannel: "channelID",
}
err := providerKeeper.OnTimeoutPacket(ctx, packet)

testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID")
require.NoError(t, err)
}

// TestOnAcknowledgementPacketWithNoAckError tests `OnAcknowledgementPacket` when the underlying ack contains no error
func TestOnAcknowledgementPacketWithNoAckError(t *testing.T) {
// Keeper setup
providerKeeper, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()

ack := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Result{Result: []byte{}}}
err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ack)
require.NoError(t, err)
}

// TestOnAcknowledgementPacketWithAckError tests `OnAcknowledgementPacket` when the underlying ack contains an error
func TestOnAcknowledgementPacketWithAckError(t *testing.T) {
// Keeper setup
providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()
providerKeeper.SetParams(ctx, providertypes.DefaultParams())

// test that `OnAcknowledgementPacket` returns an error if the ack contains an error and the channel is unknown
ackError := channeltypes.Acknowledgement{Response: &channeltypes.Acknowledgement_Error{Error: "some error"}}
err := providerKeeper.OnAcknowledgementPacket(ctx, channeltypes.Packet{}, ackError)
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), providertypes.ErrUnknownConsumerChannelId.Error()))

// test that we stop the consumer chain when `OnAcknowledgementPacket` returns an error and the chain is found
testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks)
packet := channeltypes.Packet{
SourceChannel: "channelID",
}

err = providerKeeper.OnAcknowledgementPacket(ctx, packet, ackError)

testkeeper.TestProviderStateIsCleanedAfterConsumerChainIsStopped(t, ctx, providerKeeper, "chainID", "channelID")
require.NoError(t, err)
}
1 change: 1 addition & 0 deletions x/ccv/provider/proposal_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func TestProviderProposalHandler(t *testing.T) {

case tc.expValidConsumerRemoval:
testkeeper.SetupForStoppingConsumerChain(t, ctx, &providerKeeper, mocks)
testkeeper.SetupForStoppingConsumerChainChannel(t, ctx, mocks)

case tc.expValidEquivocation:
providerKeeper.SetSlashLog(ctx, providertypes.NewProviderConsAddress(equivocation.GetConsensusAddress()))
Expand Down

0 comments on commit 92e37d8

Please sign in to comment.