From 61f36d43e0ef580d67c51fd947eeff601f0d8080 Mon Sep 17 00:00:00 2001 From: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:58:34 +0900 Subject: [PATCH] GSW-1839 Refactor/position contract utils (#433) * GSW-1839 refactor: integrated helper and test code - integrated helper with nft helper - add test helper code - add test code for helper - change file filename * GSW-1839 refactor: utils - add assert functions - refactor original util functions * Update position/utils_test.gno * test: Update to use the correct test values --------- Co-authored-by: Blake <104744707+r3v4s@users.noreply.github.com> --- position/position.gno | 20 +-- position/utils.gno | 95 +++++++++-- position/utils_test.gno | 344 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 434 insertions(+), 25 deletions(-) create mode 100644 position/utils_test.gno diff --git a/position/position.gno b/position/position.gno index 60cb0f62..f33ffd0b 100644 --- a/position/position.gno +++ b/position/position.gno @@ -105,12 +105,12 @@ func Mint( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Mint", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "tickLower", ufmt.Sprintf("%d", tickLower), "tickUpper", ufmt.Sprintf("%d", tickUpper), "poolPath", poolPath, @@ -265,12 +265,12 @@ func IncreaseLiquidity( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "IncreaseLiquidity", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "internal_poolPath", poolPath, "internal_liquidity", liquidity.ToString(), @@ -386,12 +386,12 @@ func DecreaseLiquidity( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "DecreaseLiquidity", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "liquidityRatio", ufmt.Sprintf("%d", liquidityRatio), "internal_poolPath", poolPath, @@ -615,12 +615,12 @@ func Reposition( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(position.poolKey) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Reposition", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "tickLower", ufmt.Sprintf("%d", tickLower), "tickUpper", ufmt.Sprintf("%d", tickUpper), @@ -736,12 +736,12 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri } } - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectSwapFee", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "internal_fee0", withoutFee0, "internal_fee1", withoutFee1, diff --git a/position/utils.gno b/position/utils.gno index dedba866..2121fb4f 100644 --- a/position/utils.gno +++ b/position/utils.gno @@ -6,35 +6,86 @@ import ( "gno.land/p/demo/ufmt" pusers "gno.land/p/demo/users" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) -func checkDeadline(deadline int64) { - now := time.Now().Unix() - if now > deadline { - panic(addDetailToError( - errExpired, - ufmt.Sprintf("utils.gno__checkDeadline() || transaction too old, now(%d) > deadline(%d)", now, deadline), - )) - } -} - +// a2u converts std.Address to pusers.AddressOrName. +// pusers is a package that contains the user-related functions. +// +// Input: +// - addr: the address to convert +// +// Output: +// - pusers.AddressOrName: the converted address func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } -func prevRealm() string { - return std.PrevRealm().PkgPath() +// derivePkgAddr derives the Realm address from it's pkgpath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) } -func isUserCall() bool { - return std.PrevRealm().IsUser() +// getOrigPkgAddr returns the original package address. +// In position contract, original package address is the position address. +func getOrigPkgAddr() std.Address { + return consts.POSITION_ADDR } -func getPrev() (string, string) { +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrev returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { prev := std.PrevRealm() return prev.Addr().String(), prev.PkgPath() } +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkDeadline checks if the deadline is expired. +// If the deadline is expired, it panics. +// The deadline is expired if the current time is greater than the deadline. +// Input: +// - deadline: the deadline to check +func checkDeadline(deadline int64) { + now := time.Now().Unix() + if now > deadline { + panic(newErrorWithDetail( + errExpired, + ufmt.Sprintf("transaction too old, now(%d) > deadline(%d)", now, deadline), + )) + } +} + +// assertOnlyUserOrStaker panics if the caller is not a user or staker. +func assertOnlyUserOrStaker(caller std.Realm) { + if !caller.IsUser() { + if err := common.StakerOnly(caller.Addr()); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("from (%s)", caller.Addr()), + )) + } + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + // assertOnlyValidAddress panics if the address is invalid. func assertOnlyValidAddress(addr std.Address) { if !addr.IsValid() { @@ -44,3 +95,17 @@ func assertOnlyValidAddress(addr std.Address) { )) } } + +// assertOnlyValidAddress panics if the address is invalid or previous address is not +// different from the other address. +func assertOnlyValidAddressWith(prevAddr, otherAddr std.Address) { + assertOnlyValidAddress(prevAddr) + assertOnlyValidAddress(otherAddr) + + if prevAddr != otherAddr { + panic(newErrorWithDetail( + errInvalidAddress, + ufmt.Sprintf("(%s, %s)", prevAddr, otherAddr), + )) + } +} diff --git a/position/utils_test.gno b/position/utils_test.gno new file mode 100644 index 00000000..26051a27 --- /dev/null +++ b/position/utils_test.gno @@ -0,0 +1,344 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + pusers "gno.land/p/demo/users" + "gno.land/r/demo/users" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" +) + +func TestA2u(t *testing.T) { + var ( + addr = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + input std.Address + expected pusers.AddressOrName + }{ + { + name: "Success - a2u", + input: addr, + expected: pusers.AddressOrName(addr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := a2u(tc.input) + uassert.Equal(t, users.Resolve(got).String(), users.Resolve(tc.expected).String()) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetOrigPkgAddr(t *testing.T) { + tests := []struct { + name string + expected std.Address + }{ + { + name: "Success - getOrigPkgAddr", + expected: consts.POSITION_ADDR, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := getOrigPkgAddr() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestGetPrevAsString(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prev Realm of user info as string", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got1, got2 := getPrevAsString() + uassert.Equal(t, got1, tc.expected[0]) + uassert.Equal(t, got2, tc.expected[1]) + }) + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + originPkgPath string + expected bool + }{ + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Failure - Not User Call", + originCaller: consts.ROUTER_ADDR, + originPkgPath: consts.ROUTER_PATH, + expected: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + if !tc.expected { + std.TestSetRealm(std.NewCodeRealm(tc.originPkgPath)) + } + got := isUserCall() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestCheckDeadline(t *testing.T) { + tests := []struct { + name string + deadline int64 + now int64 + expected string + }{ + { + name: "Success - checkDeadline", + deadline: 1234567890 + 100, + now: 1234567890, + expected: "", + }, + { + name: "Failure - checkDeadline", + deadline: 1234567890 - 100, + now: 1234567890, + expected: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234567790)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected != "" { + uassert.PanicsWithMessage(t, tc.expected, func() { + checkDeadline(tc.deadline) + }) + } else { + uassert.NotPanics(t, func() { + checkDeadline(tc.deadline) + }) + } + }) + } +} + +func TestAssertOnlyUserOrStaker(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected bool + }{ + { + name: "Failure - Not User or Staker", + originCaller: consts.ROUTER_ADDR, + expected: false, + }, + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Success - Staker Call", + originCaller: consts.STAKER_ADDR, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + assertOnlyUserOrStaker(std.PrevRealm()) + }) + } +} + +func TestAssertOnlyNotHalted(t *testing.T) { + tests := []struct { + name string + expected bool + panicMsg string + }{ + { + name: "Failure - Halted", + expected: false, + panicMsg: "[GNOSWAP-COMMON-002] halted || gnoswap halted", + }, + { + name: "Success - Not Halted", + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyNotHalted() + }) + } else { + std.TestSetRealm(std.NewUserRealm(users.Resolve(admin))) + common.SetHaltByAdmin(true) + uassert.PanicsWithMessage(t, tc.panicMsg, func() { + assertOnlyNotHalted() + }) + common.SetHaltByAdmin(false) + } + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POSITION-011] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} + +func TestAssertOnlyValidAddressWith(t *testing.T) { + tests := []struct { + name string + addr std.Address + other std.Address + expected bool + errorMsg string + }{ + { + name: "Success - validation address check to compare with other address", + addr: consts.ADMIN, + other: std.Address("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d"), + expected: true, + }, + { + name: "Failure - two address is different", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", + other: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + expected: false, + errorMsg: "[GNOSWAP-POSITION-011] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } + }) + } +}