Skip to content

Commit

Permalink
Make sure from address != to address in tc claim
Browse files Browse the repository at this point in the history
  • Loading branch information
toshiSat committed Jan 30, 2025
1 parent 7338e85 commit 5cf6280
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
7 changes: 6 additions & 1 deletion x/claim/keeper/msg_server_claim_thorchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ func (k msgServer) ClaimThorchain(goCtx context.Context, msg *types.MsgClaimThor
ctx := sdk.UnwrapSDKContext(goCtx)
k.Logger(ctx).Info(msg.Creator)

// Add check for matching addresses
if msg.FromAddress == msg.ToAddress {
return nil, errors.Wrapf(types.ErrInvalidAddress, "from address and to address cannot be the same: %s", msg.FromAddress)
}

// only allow thorchain claim server address to call this function
if msg.Creator != "tarkeo1z02ke8639m47g9dfrheegr2u9zecegt5qvtj00" && msg.Creator != "arkeo1z02ke8639m47g9dfrheegr2u9zecegt50fjg7v" {
return nil, errors.Wrapf(types.ErrInvalidCreator, "Invalid Creator %s", msg.Creator)
Expand All @@ -23,7 +28,7 @@ func (k msgServer) ClaimThorchain(goCtx context.Context, msg *types.MsgClaimThor
if err != nil {
return nil, errors.Wrapf(err, "failed to get claim record for %s", msg.FromAddress)
}
if fromAddressClaimRecord.IsEmpty() || fromAddressClaimRecord.AmountClaim.IsZero() {
if fromAddressClaimRecord.IsEmpty() || (fromAddressClaimRecord.AmountClaim.IsZero() && fromAddressClaimRecord.AmountVote.IsZero() && fromAddressClaimRecord.AmountDelegate.IsZero()) {
return nil, errors.Wrapf(types.ErrNoClaimableAmount, "no claimable amount for %s", msg.FromAddress)
}

Expand Down
51 changes: 51 additions & 0 deletions x/claim/keeper/msg_server_claim_thorchain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,54 @@ func TestClaimThorchainMainnetAddress(t *testing.T) {
_, err = msgServer.ClaimThorchain(ctx, &claimMessage)
require.ErrorIs(t, err, types.ErrNoClaimableAmount)
}

func TestClaimThorchainFailureCases(t *testing.T) {
msgServer, keepers, ctx := setupMsgServer(t)
sdkCtx := sdk.UnwrapSDKContext(ctx)

config := sdk.GetConfig()
config.SetBech32PrefixForAccount("arkeo", "arkeopub")

arkeoServerAddress, err := sdk.AccAddressFromBech32("arkeo1z02ke8639m47g9dfrheegr2u9zecegt50fjg7v")
require.NoError(t, err)

fromAddr := utils.GetRandomArkeoAddress()
toAddr := utils.GetRandomArkeoAddress()

// Test case 1: Same from and to address
sameAddressMsg := types.MsgClaimThorchain{
Creator: arkeoServerAddress.String(),
FromAddress: fromAddr.String(),
ToAddress: fromAddr.String(),
}
_, err = msgServer.ClaimThorchain(ctx, &sameAddressMsg)
require.ErrorIs(t, types.ErrInvalidAddress, err)

// Test case 2: Empty claim record for from address
emptyFromMsg := types.MsgClaimThorchain{
Creator: arkeoServerAddress.String(),
FromAddress: fromAddr.String(),
ToAddress: toAddr.String(),
}
_, err = msgServer.ClaimThorchain(ctx, &emptyFromMsg)
require.ErrorIs(t, types.ErrNoClaimableAmount, err)

// Test case 3: Zero amount claim record
zeroClaimRecord := types.ClaimRecord{
Chain: types.ARKEO,
Address: fromAddr.String(),
AmountClaim: sdk.NewInt64Coin(types.DefaultClaimDenom, 0),
AmountVote: sdk.NewInt64Coin(types.DefaultClaimDenom, 0),
AmountDelegate: sdk.NewInt64Coin(types.DefaultClaimDenom, 0),
}
err = keepers.ClaimKeeper.SetClaimRecord(sdkCtx, zeroClaimRecord)
require.NoError(t, err)

zeroAmountMsg := types.MsgClaimThorchain{
Creator: arkeoServerAddress.String(),
FromAddress: fromAddr.String(),
ToAddress: toAddr.String(),
}
_, err = msgServer.ClaimThorchain(ctx, &zeroAmountMsg)
require.ErrorIs(t, types.ErrNoClaimableAmount, err)
}
1 change: 1 addition & 0 deletions x/claim/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ var (
ErrInvalidSignature = errors.Register(ModuleName, 3, "Invalid signature")
ErrClaimRecordNotTransferrable = errors.Register(ModuleName, 4, "Claim record can not be transferred")
ErrInvalidCreator = errors.Register(ModuleName, 5, "Invalid Creator")
ErrInvalidAddress = errors.Register(ModuleName, 6, "invalid address")
)

0 comments on commit 5cf6280

Please sign in to comment.