From 035a7961a8ddc2795d24df6159b9e4469ac40ac2 Mon Sep 17 00:00:00 2001 From: Blake <104744707+r3v4s@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:52:14 +0900 Subject: [PATCH 1/4] GSW-1968 feat: check emission is caller or not (#435) * feat: add check logic for caller is emission or not * fix: remove deprecated require from gno.mod --- _deploy/r/gnoswap/common/access.gno | 39 ++++--- _deploy/r/gnoswap/common/access_test.gno | 136 +++++++++++++++++++++++ _deploy/r/gnoswap/common/gno.mod | 8 -- 3 files changed, 162 insertions(+), 21 deletions(-) create mode 100644 _deploy/r/gnoswap/common/access_test.gno diff --git a/_deploy/r/gnoswap/common/access.gno b/_deploy/r/gnoswap/common/access.gno index 459a66a2d..c96bc2efd 100644 --- a/_deploy/r/gnoswap/common/access.gno +++ b/_deploy/r/gnoswap/common/access.gno @@ -7,9 +7,13 @@ import ( "gno.land/r/gnoswap/v1/consts" ) +const ( + ErrNoPermission = "caller(%s) has no permission" +) + func AssertCaller(caller, addr std.Address) error { if caller != addr { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } @@ -23,56 +27,65 @@ func SatisfyCond(cond bool) error { func AdminOnly(caller std.Address) error { if caller != consts.ADMIN { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } func GovernanceOnly(caller std.Address) error { if caller != consts.GOV_GOVERNANCE_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } func GovStakerOnly(caller std.Address) error { if caller != consts.GOV_STAKER_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } func RouterOnly(caller std.Address) error { if caller != consts.ROUTER_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } func PositionOnly(caller std.Address) error { if caller != consts.POSITION_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } func StakerOnly(caller std.Address) error { if caller != consts.STAKER_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } -func TokenRegisterOnly(caller std.Address) error { - if caller != consts.TOKEN_REGISTER { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) +func LaunchpadOnly(caller std.Address) error { + if caller != consts.LAUNCHPAD_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } -func LaunchpadOnly(caller std.Address) error { - if caller != consts.LAUNCHPAD_ADDR { - return ufmt.Errorf("caller(%s) has no permission", caller.String()) +func EmissionOnly(caller std.Address) error { + if caller != consts.EMISSION_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// DEPRECATED +// TODO: remove after r/grc20reg is applied for all contracts +func TokenRegisterOnly(caller std.Address) error { + if caller != consts.TOKEN_REGISTER { + return ufmt.Errorf(ErrNoPermission, caller.String()) } return nil } diff --git a/_deploy/r/gnoswap/common/access_test.gno b/_deploy/r/gnoswap/common/access_test.gno new file mode 100644 index 000000000..0c1dd91b5 --- /dev/null +++ b/_deploy/r/gnoswap/common/access_test.gno @@ -0,0 +1,136 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/consts" +) + +var ( + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +func TestAssertCaller(t *testing.T) { + t.Run("same caller", func(t *testing.T) { + uassert.NoError(t, AssertCaller(addr01, addr01)) + }) + + t.Run("different caller", func(t *testing.T) { + uassert.Error(t, AssertCaller(addr01, addr02)) + }) +} + +func TestSatisfyCond(t *testing.T) { + t.Run("true", func(t *testing.T) { + uassert.NoError(t, SatisfyCond(true)) + }) + + t.Run("false", func(t *testing.T) { + uassert.Error(t, SatisfyCond(false)) + }) +} + +func TestAdminOnly(t *testing.T) { + t.Run("caller is admin", func(t *testing.T) { + uassert.NoError(t, AdminOnly(consts.ADMIN)) + }) + + t.Run("caller is not admin", func(t *testing.T) { + uassert.Error(t, AdminOnly(addr01)) + }) +} + +func TestGovernanceOnly(t *testing.T) { + t.Run("caller is governance", func(t *testing.T) { + uassert.NoError(t, GovernanceOnly(consts.GOV_GOVERNANCE_ADDR)) + }) + + t.Run("caller is not governance", func(t *testing.T) { + uassert.Error(t, GovernanceOnly(addr01)) + }) +} + +func TestGovStakerOnly(t *testing.T) { + t.Run("caller is gov staker", func(t *testing.T) { + uassert.NoError(t, GovStakerOnly(consts.GOV_STAKER_ADDR)) + }) + + t.Run("caller is not gov staker", func(t *testing.T) { + uassert.Error(t, GovStakerOnly(addr01)) + }) +} + +func TestRouterOnly(t *testing.T) { + t.Run("caller is router", func(t *testing.T) { + uassert.NoError(t, RouterOnly(consts.ROUTER_ADDR)) + }) + + t.Run("caller is not router", func(t *testing.T) { + uassert.Error(t, RouterOnly(addr01)) + }) +} + +func TestPositionOnly(t *testing.T) { + t.Run("caller is position", func(t *testing.T) { + uassert.NoError(t, PositionOnly(consts.POSITION_ADDR)) + }) + + t.Run("caller is not position", func(t *testing.T) { + uassert.Error(t, PositionOnly(addr01)) + }) +} + +func TestStakerOnly(t *testing.T) { + t.Run("caller is staker", func(t *testing.T) { + uassert.NoError(t, StakerOnly(consts.STAKER_ADDR)) + }) + + t.Run("caller is not staker", func(t *testing.T) { + uassert.Error(t, StakerOnly(addr01)) + }) +} + +func TestLaunchpadOnly(t *testing.T) { + t.Run("caller is launchpad", func(t *testing.T) { + uassert.NoError(t, LaunchpadOnly(consts.LAUNCHPAD_ADDR)) + }) + + t.Run("caller is not launchpad", func(t *testing.T) { + uassert.Error(t, LaunchpadOnly(addr01)) + }) +} + +func TestEmissionOnly(t *testing.T) { + t.Run("caller is emission", func(t *testing.T) { + uassert.NoError(t, EmissionOnly(consts.EMISSION_ADDR)) + }) + + t.Run("caller is not emission", func(t *testing.T) { + uassert.Error(t, EmissionOnly(addr01)) + }) +} + +func TestTokenRegisterOnly(t *testing.T) { + t.Run("caller is token register", func(t *testing.T) { + uassert.NoError(t, TokenRegisterOnly(consts.TOKEN_REGISTER)) + }) + + t.Run("caller is not token register", func(t *testing.T) { + uassert.Error(t, TokenRegisterOnly(addr01)) + }) +} + +func TestUserOnly(t *testing.T) { + t.Run("caller is user", func(t *testing.T) { + uassert.NoError(t, UserOnly(std.NewUserRealm(addr01))) + }) + + t.Run("caller is not user", func(t *testing.T) { + uassert.Error(t, UserOnly(std.NewCodeRealm("gno.land/r/realm"))) + }) +} diff --git a/_deploy/r/gnoswap/common/gno.mod b/_deploy/r/gnoswap/common/gno.mod index 944098d5b..ad0e3a333 100644 --- a/_deploy/r/gnoswap/common/gno.mod +++ b/_deploy/r/gnoswap/common/gno.mod @@ -1,9 +1 @@ module gno.land/r/gnoswap/v1/common - -require ( - gno.land/p/demo/ufmt v0.0.0-latest - gno.land/p/gnoswap/int256 v0.0.0-latest - gno.land/p/gnoswap/pool v0.0.0-latest - gno.land/p/gnoswap/uint256 v0.0.0-latest - gno.land/r/gnoswap/v1/consts v0.0.0-latest -) From c2a133b3e9078df124d2f193ae73f60e6c09221b Mon Sep 17 00:00:00 2001 From: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:47:55 +0900 Subject: [PATCH 2/4] Refactor/pool total (#442) * refactor: pool mint * refactor: name change transferFromAndVerify to safeTransferFrom * refactor: support checkTick when position modify * refactor: tick and position update * refactor: pool mint * refactor: collect and burn * refactor: collectProtocol * refactor: setFeeProtocol * refactor: createPool * refactor: Modify code based on code review comments --- _deploy/r/gnoswap/consts/consts.gno | 14 +- pool/errors.gno | 9 +- pool/getter.gno | 45 ++- pool/liquidity_math.gno | 4 +- pool/liquidity_math_test.gno | 4 +- pool/pool.gno | 376 ++++++++++++++--------- pool/pool_manager.gno | 106 +++++-- pool/pool_manager_test.gno | 56 ++++ pool/pool_test.gno | 30 +- pool/pool_transfer.gno | 100 ++++-- pool/pool_transfer_test.gno | 10 +- pool/position.gno | 123 +++++--- pool/position_modify.gno | 55 +++- pool/position_modify_test.gno | 156 ++++++++++ pool/position_test.gno | 21 +- pool/position_update.gno | 79 +++-- pool/position_update_test.gno | 23 +- pool/protocol_fee_pool_creation.gno | 4 +- pool/protocol_fee_withdrawal.gno | 6 +- pool/swap.gno | 10 +- pool/tests/__TEST_pool_burn_test.gnoA | 2 +- pool/tests/__TEST_pool_spec_#6_test.gnoA | 2 +- pool/tick.gno | 253 +++++++++++++-- pool/tick_bitmap.gno | 44 ++- pool/tick_test.gno | 161 ++++++++-- pool/type.gno | 84 +++-- pool/utils.gno | 242 ++++++++++++++- pool/utils_test.gno | 297 +++++++++++++++++- position/_RPC_api.gno | 19 +- position/liquidity_management.gno | 2 +- position/position.gno | 15 +- 31 files changed, 1868 insertions(+), 484 deletions(-) diff --git a/_deploy/r/gnoswap/consts/consts.gno b/_deploy/r/gnoswap/consts/consts.gno index 9f9228e9f..d2ebff665 100644 --- a/_deploy/r/gnoswap/consts/consts.gno +++ b/_deploy/r/gnoswap/consts/consts.gno @@ -6,9 +6,8 @@ import ( // GNOSWAP SERVICE const ( - ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" // Admin - DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" // DevOps - + ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" + DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" TOKEN_REGISTER std.Address = "g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" TOKEN_REGISTER_NAMESPACE string = "gno.land/r/g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" @@ -21,7 +20,8 @@ const ( GNOT string = "gnot" WRAPPED_WUGNOT string = "gno.land/r/demo/wugnot" - UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 ) // CONTRACT PATH & ADDRESS @@ -91,9 +91,11 @@ const ( MAX_UINT128 string = "340282366920938463463374607431768211455" MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" - MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT128 string = "170141183460469231731687303715884105727" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" + // Tick Related MIN_TICK int32 = -887272 MAX_TICK int32 = 887272 @@ -108,6 +110,8 @@ const ( Q64 string = "18446744073709551616" // 2 ** 64 Q96 string = "79228162514264337593543950336" // 2 ** 96 Q128 string = "340282366920938463463374607431768211456" // 2 ** 128 + + Q128_RESOLUTION uint = 128 ) // TIMESTAMP & DAY diff --git a/pool/errors.gno b/pool/errors.gno index 11c4efa57..c98286a4d 100644 --- a/pool/errors.gno +++ b/pool/errors.gno @@ -31,8 +31,13 @@ var ( errTransferFailed = errors.New("[GNOSWAP-POOL-021] token transfer failed") errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") - errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than tickUpper") - errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") // TODO: make as common error code + errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper") + errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") + errOverFlow = errors.New("[GNOSWAP-POOL-026] overflow") + errBalanceUpdateFailed = errors.New("[GNOSWAP-POOL-027] balance update failed") + errTickLowerInvalid = errors.New("[GNOSWAP-POOL-028] tickLower is invalid") + errTickUpperInvalid = errors.New("[GNOSWAP-POOL-029] tickUpper is invalid") + errTickLowerGtTickUpper = errors.New("[GNOSWAP-POOL-030] tickLower is greater than tickUpper") ) // addDetailToError adds detail to an error message diff --git a/pool/getter.gno b/pool/getter.gno index f962846fa..443516141 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -1,6 +1,5 @@ package pool -// pool func PoolGetPoolList() []string { poolPaths := []string{} for poolPath, _ := range pools { @@ -11,91 +10,89 @@ func PoolGetPoolList() []string { } func PoolGetToken0Path(poolPath string) string { - return mustGetPool(poolPath).GetToken0Path() + return mustGetPool(poolPath).Token0Path() } func PoolGetToken1Path(poolPath string) string { - return mustGetPool(poolPath).GetToken1Path() + return mustGetPool(poolPath).Token1Path() } func PoolGetFee(poolPath string) uint32 { - return mustGetPool(poolPath).GetFee() + return mustGetPool(poolPath).Fee() } func PoolGetBalanceToken0(poolPath string) string { - return mustGetPool(poolPath).GetBalanceToken0().ToString() + return mustGetPool(poolPath).BalanceToken0().ToString() } func PoolGetBalanceToken1(poolPath string) string { - return mustGetPool(poolPath).GetBalanceToken1().ToString() + return mustGetPool(poolPath).BalanceToken1().ToString() } func PoolGetTickSpacing(poolPath string) int32 { - return mustGetPool(poolPath).GetTickSpacing() + return mustGetPool(poolPath).TickSpacing() } func PoolGetMaxLiquidityPerTick(poolPath string) string { - return mustGetPool(poolPath).GetMaxLiquidityPerTick().ToString() + return mustGetPool(poolPath).MaxLiquidityPerTick().ToString() } func PoolGetSlot0SqrtPriceX96(poolPath string) string { - return mustGetPool(poolPath).GetSlot0SqrtPriceX96().ToString() + return mustGetPool(poolPath).Slot0SqrtPriceX96().ToString() } func PoolGetSlot0Tick(poolPath string) int32 { - return mustGetPool(poolPath).GetSlot0Tick() + return mustGetPool(poolPath).Slot0Tick() } func PoolGetSlot0FeeProtocol(poolPath string) uint8 { - return mustGetPool(poolPath).GetSlot0FeeProtocol() + return mustGetPool(poolPath).Slot0FeeProtocol() } func PoolGetSlot0Unlocked(poolPath string) bool { - return mustGetPool(poolPath).GetSlot0Unlocked() + return mustGetPool(poolPath).Slot0Unlocked() } func PoolGetFeeGrowthGlobal0X128(poolPath string) string { - return mustGetPool(poolPath).GetFeeGrowthGlobal0X128().ToString() + return mustGetPool(poolPath).FeeGrowthGlobal0X128().ToString() } func PoolGetFeeGrowthGlobal1X128(poolPath string) string { - return mustGetPool(poolPath).GetFeeGrowthGlobal1X128().ToString() + return mustGetPool(poolPath).FeeGrowthGlobal1X128().ToString() } func PoolGetProtocolFeesToken0(poolPath string) string { - return mustGetPool(poolPath).GetProtocolFeesToken0().ToString() + return mustGetPool(poolPath).ProtocolFeesToken0().ToString() } func PoolGetProtocolFeesToken1(poolPath string) string { - return mustGetPool(poolPath).GetProtocolFeesToken1().ToString() + return mustGetPool(poolPath).ProtocolFeesToken1().ToString() } func PoolGetLiquidity(poolPath string) string { - return mustGetPool(poolPath).GetLiquidity().ToString() + return mustGetPool(poolPath).Liquidity().ToString() } -// position func PoolGetPositionLiquidity(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionLiquidity(key).ToString() + return mustGetPool(poolPath).PositionLiquidity(key).ToString() } func PoolGetPositionFeeGrowthInside0LastX128(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionFeeGrowthInside0LastX128(key).ToString() + return mustGetPool(poolPath).PositionFeeGrowthInside0LastX128(key).ToString() } func PoolGetPositionFeeGrowthInside1LastX128(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionFeeGrowthInside1LastX128(key).ToString() + return mustGetPool(poolPath).PositionFeeGrowthInside1LastX128(key).ToString() } func PoolGetPositionTokensOwed0(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionTokensOwed0(key).ToString() + return mustGetPool(poolPath).PositionTokensOwed0(key).ToString() } func PoolGetPositionTokensOwed1(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionTokensOwed1(key).ToString() + return mustGetPool(poolPath).PositionTokensOwed1(key).ToString() } -// tick func PoolGetTickLiquidityGross(poolPath string, tick int32) string { return mustGetPool(poolPath).GetTickLiquidityGross(tick).ToString() } diff --git a/pool/liquidity_math.gno b/pool/liquidity_math.gno index 47aaea2cc..82a8b4317 100644 --- a/pool/liquidity_math.gno +++ b/pool/liquidity_math.gno @@ -44,7 +44,7 @@ func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if z.Gte(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than Condition(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Condition failed: (z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } else { @@ -52,7 +52,7 @@ func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if z.Lt(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Condition failed: (z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } diff --git a/pool/liquidity_math_test.gno b/pool/liquidity_math_test.gno index fc2f379b2..6c38f3df7 100644 --- a/pool/liquidity_math_test.gno +++ b/pool/liquidity_math_test.gno @@ -42,7 +42,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than Condition(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), + ufmt.Sprintf("Condition failed: (z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), }, { name: "overflow panic with add delta", @@ -53,7 +53,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), + ufmt.Sprintf("Condition failed: (z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), }, } diff --git a/pool/pool.gno b/pool/pool.gno index db4df9b8e..df2e51ee8 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -12,8 +12,32 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -// Mint creates a new position and mints liquidity tokens. -// Returns minted amount0, amount1 in string +// Mint adds liquidity to a pool by minting a new position. +// +// This function mints a liquidity position within the specified tick range in a pool. +// It verifies caller permissions, validates inputs, and updates the pool's state. Additionally, +// it transfers the required amounts of token0 and token1 from the caller to the pool. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the newly created position. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to add, provided as a decimal string. +// - positionCaller: std.Address, the address of the entity calling the function (e.g., the position owner). +// +// Returns: +// - string: The amount of token 0 transferred to the pool as a string. +// - string: The amount of token 1 transferred to the pool as a string. +// +// Panic Conditions: +// - The system is halted (`common.IsHalted()`). +// - Caller lacks permission to mint a position when `common.GetLimitCaller()` is enforced. +// - The provided `liquidityAmount` is zero. +// - Any failure during token transfers or position modifications. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#mint func Mint( token0Path string, @@ -22,43 +46,54 @@ func Mint( recipient std.Address, tickLower int32, tickUpper int32, - _liquidityAmount string, + liquidityAmount string, positionCaller std.Address, ) (string, string) { - common.IsHalted() + assertOnlyNotHalted() if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } - liquidityAmount := u256.MustFromDecimal(_liquidityAmount) - if liquidityAmount.IsZero() { + liquidity := u256.MustFromDecimal(liquidityAmount) + if liquidity.IsZero() { panic(errZeroLiquidity) } pool := GetPool(token0Path, token1Path, fee) - position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + liquidityDelta := safeConvertToInt128(liquidity) + position := newModifyPositionParams(recipient, tickLower, tickUpper, liquidityDelta) _, amount0, amount1 := pool.modifyPosition(position) if amount0.Gt(u256.Zero()) { - pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) } if amount1.Gt(u256.Zero()) { - pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) } return amount0.ToString(), amount1.ToString() } -// Burn removes liquidity from the caller and account tokens owed for the liquidity to the position -// If liquidity of 0 is burned, it recalculates fees owed to a position -// Returns burned amount0, amount1 in string +// Burn removes liquidity from a position in the pool. +// +// This function allows the caller to burn (remove) a specified amount of liquidity from a position. +// It calculates the amounts of token0 and token1 released when liquidity is removed and updates +// the position's owed token amounts. The actual transfer of tokens back to the caller happens +// during a separate `Collect()` operation. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to remove, provided as a decimal string (uint128). +// +// Returns: +// - string: The amount of token0 owed after removing liquidity, as a string. +// - string: The amount of token1 owed after removing liquidity, as a string. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#burn func Burn( token0Path string, @@ -68,40 +103,54 @@ func Burn( tickUpper int32, liquidityAmount string, // uint128 ) (string, string) { // uint256 x2 - common.IsHalted() - caller := std.PrevRealm().Addr() + assertOnlyNotHalted() if common.GetLimitCaller() { - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } - - liqAmount := u256.MustFromDecimal(liquidityAmount) - pool := GetPool(token0Path, token1Path, fee) - liqDelta := i256.Zero().Neg(i256.FromUint256(liqAmount)) + caller := getPrevAddr() + liqAmount := u256.MustFromDecimal(liquidityAmount) + liqAmountInt256 := safeConvertToInt128(liqAmount) + liqDelta := i256.Zero().Neg(liqAmountInt256) posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) position, amount0, amount1 := pool.modifyPosition(posParams) if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { + amount0 = toUint128(amount0) + amount1 = toUint128(amount1) position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) } - positionKey := positionGetKey(caller, tickLower, tickUpper) + positionKey := getPositionKey(caller, tickLower, tickUpper) pool.positions[positionKey] = position // actual token transfer happens in Collect() return amount0.ToString(), amount1.ToString() } -// Collect collects tokens owed to a position -// Burned amounts, and swap fees will be transferred to the caller -// Returns collected amount0, amount1 in string +// Collect handles the collection of tokens (token0 and token1) from a liquidity position. +// +// This function allows the caller to collect a specified amount of tokens owed to a position +// in a liquidity pool. It calculates the collectible amount based on three constraints: +// the requested amount, the tokens owed to the position, and the pool's available balance. +// The collected tokens are transferred to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected tokens. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - amount0Requested: string, the requested amount of token 0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token 1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token 0 collected, as a string. +// - string: The actual amount of token 1 collected, as a string. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#collect func Collect( token0Path string, @@ -113,76 +162,61 @@ func Collect( amount0Requested string, amount1Requested string, ) (string, string) { - common.IsHalted() + assertOnlyNotHalted() if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } pool := GetPool(token0Path, token1Path, fee) - - positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper) - position, exist := pool.positions[positionKey] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("positionKey(%s) does not exist", positionKey), - )) - } + positionKey := getPositionKey(getPrevAddr(), tickLower, tickUpper) + position := pool.mustGetPosition(positionKey) var amount0, amount1 *u256.Uint // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 amount0Req := u256.MustFromDecimal(amount0Requested) - amount0, position.tokensOwed0, pool.balances.token0 = collectToken(amount0Req, position.tokensOwed0, pool.balances.token0) - token0 := common.GetTokenTeller(pool.token0Path) - checkTransferError(token0.Transfer(recipient, amount0.Uint64())) - - // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount0 = collectToken(amount0Req, position.tokensOwed0, pool.BalanceToken0()) amount1Req := u256.MustFromDecimal(amount1Requested) - amount1, position.tokensOwed1, pool.balances.token1 = collectToken(amount1Req, position.tokensOwed1, pool.balances.token1) + amount1 = collectToken(amount1Req, position.tokensOwed1, pool.BalanceToken1()) - // Update state first then transfer - position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) - pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) - token1 := common.GetTokenTeller(pool.token1Path) - checkTransferError(token1.Transfer(recipient, amount1.Uint64())) + if amount0.Gt(u256.Zero()) { + position.tokensOwed0 = new(u256.Uint).Sub(position.tokensOwed0, amount0) + pool.balances.token0 = new(u256.Uint).Sub(pool.balances.token0, amount0) + token0 := common.GetTokenTeller(pool.token0Path) + checkTransferError(token0.Transfer(recipient, amount0.Uint64())) + } + if amount1.Gt(u256.Zero()) { + position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) + pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) + token1 := common.GetTokenTeller(pool.token1Path) + checkTransferError(token1.Transfer(recipient, amount1.Uint64())) + } pool.positions[positionKey] = position return amount0.ToString(), amount1.ToString() } -// collectToken handles the collection of tokens (either token0 or token1) from a position. -// It calculates the actual amount that can be collected based on three constraints: -// the requested amount, tokens owed, and available pool balance. +// collectToken calculates the actual amount of tokens that can be collected. +// +// This function determines the smallest possible value among the requested amount (`amountReq`), +// the tokens owed (`tokensOwed`), and the pool's available balance (`poolBalance`). It ensures +// the collected amount does not exceed any of these constraints. // // Parameters: -// - amountReq: amount requested to collect -// - tokensOwed: amount of tokens owed to the position -// - poolBalance: current balance of tokens in the pool +// - amountReq: *u256.Uint, the amount of tokens requested for collection. +// - tokensOwed: *u256.Uint, the total amount of tokens owed to the position. +// - poolBalance: *u256.Uint, the current balance of tokens available in the pool. // // Returns: -// - amount: actual amount that will be collected (minimum of the three inputs) -// - newTokensOwed: remaining tokens owed after collection -// - newPoolBalance: remaining pool balance after collection +// - amount: *u256.Uint, the actual amount that can be collected (minimum of the three inputs). func collectToken( amountReq, tokensOwed, poolBalance *u256.Uint, -) (amount, newTokensOwed, newPoolBalance *u256.Uint) { +) (amount *u256.Uint) { // find smallest of three amounts amount = u256Min(amountReq, tokensOwed) amount = u256Min(amount, poolBalance) - - // value for update state - newTokensOwed = new(u256.Uint).Sub(tokensOwed, amount) - newPoolBalance = new(u256.Uint).Sub(poolBalance, amount) - - return amount, newTokensOwed, newPoolBalance + return amount.Clone() } // SetFeeProtocolByAdmin sets the fee protocol for all pools @@ -191,22 +225,8 @@ func SetFeeProtocolByAdmin( feeProtocol0 uint8, feeProtocol1 uint8, ) { - caller := std.PrevRealm().Addr() - if err := common.AdminOnly(caller); err != nil { - panic(err) - } - - newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - - prevAddr, prevRealm := getPrev() - std.Emit( - "SetFeeProtocolByAdmin", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "feeProtocol0", ufmt.Sprintf("%d", feeProtocol0), - "feeProtocol1", ufmt.Sprintf("%d", feeProtocol1), - "internal_newFee", ufmt.Sprintf("%d", newFee), - ) + assertOnlyAdmin() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocolByAdmin") } // SetFeeProtocol sets the fee protocol for all pools @@ -214,40 +234,62 @@ func SetFeeProtocolByAdmin( // Also it will be applied to new created pools // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#setfeeprotocol func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { - caller := std.PrevRealm().Addr() - if err := common.GovernanceOnly(caller); err != nil { - panic(err) - } + assertOnlyGovernance() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocol") +} +// setFeeProtocolInternal updates the protocol fee for all pools and emits an event. +// +// This function is an internal utility used to set the protocol fee for token0 and token1 in a compact +// format. The fee values are stored as a single `uint8` byte where: +// - Lower 4 bits represent the fee for token0 (feeProtocol0). +// - Upper 4 bits represent the fee for token1 (feeProtocol1). +// +// It also emits an event to log the changes, including the previous and new fee protocol values. +// +// Parameters: +// - feeProtocol0: uint8, protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: uint8, protocol fee for token1 (must be 0 or between 4 and 10 inclusive). +// - eventName: string, the name of the event to emit (e.g., "SetFeeProtocolByAdmin"). +// +// Notes: +// - This function is called by higher-level functions like `SetFeeProtocolByAdmin` or `SetFeeProtocol`. +// - It does not validate caller permissions; validation must be performed by the calling function. +func setFeeProtocolInternal(feeProtocol0, feeProtocol1 uint8, eventName string) { + oldFee := slot0FeeProtocol newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - prevAddr, prevRealm := getPrev() + feeProtocol0Old := oldFee % 16 + feeProtocol1Old := oldFee >> 4 + + prevAddr, prevPkgPath := getPrevAsString() std.Emit( - "SetFeeProtocol", + eventName, "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, + "feeProtocol0Old", ufmt.Sprintf("%d", feeProtocol0Old), + "feeProtocol1Old", ufmt.Sprintf("%d", feeProtocol1Old), "feeProtocol0", ufmt.Sprintf("%d", feeProtocol0), "feeProtocol1", ufmt.Sprintf("%d", feeProtocol1), "internal_newFee", ufmt.Sprintf("%d", newFee), ) } -// setFeeProtocol updates the protocol fee configuration for all existing pools and sets -// the default for new pools. This is an internal function called by both `admin` and `governance` -// protocol fee management functions. +// setFeeProtocol updates the protocol fee configuration for all managed pools. // -// The protocol fee is stored as a single `uint8` value where: -// - Lower 4 bits store feeProtocol0 (for token0) -// - Upper 4 bits store feeProtocol1 (for token1) +// This function combines the protocol fee values for token0 and token1 into a single `uint8` value, +// where: +// - Lower 4 bits store feeProtocol0 (for token0). +// - Upper 4 bits store feeProtocol1 (for token1). // -// This compact representation allows storing both fee values in a single byte. +// The updated fee protocol is applied uniformly to all pools managed by the system. // -// Parameters (must be 0 or between 4 and 10 inclusive): -// - feeProtocol0: protocol fee for token0 -// - feeProtocol1: protocol fee for token1 +// Parameters: +// - feeProtocol0: protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: protocol fee for token1 (must be 0 or between 4 and 10 inclusive). // // Returns: -// - newFee (uint8): the combined fee protocol value +// - newFee (uint8): the combined fee protocol value. // // Example: // If feeProtocol0 = 4 and feeProtocol1 = 5: @@ -257,32 +299,47 @@ func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { // // Binary: 0101 0100 // // ^^^^ ^^^^ // // fee1=5 fee0=4 +// +// Notes: +// - This function ensures that all pools under management are updated to use the same fee protocol. +// - Caller restrictions (e.g., admin or governance) are not enforced in this function. +// - Ensure the system is not halted before updating fees. func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { - common.IsHalted() - + assertOnlyNotHalted() if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( err, ufmt.Sprintf("expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } - // combine both protocol fee into a single byte: // - feePrtocol0 occupies the lower 4 bits // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) - - // iterate all pool + // Update slot0 for each pool for _, pool := range pools { - pool.slot0.feeProtocol = newFee + if pool != nil { + pool.slot0.feeProtocol = newFee + } } - // update slot0 slot0FeeProtocol = newFee - return newFee } +// validateFeeProtocol validates the fee protocol values for token0 and token1. +// +// This function checks whether the provided fee protocol values (`feeProtocol0` and `feeProtocol1`) +// are valid using the `isValidFeeProtocolValue` function. If either value is invalid, it returns +// an error indicating that the protocol fee percentage is invalid. +// +// Parameters: +// - feeProtocol0: uint8, the fee protocol value for token0. +// - feeProtocol1: uint8, the fee protocol value for token1. +// +// Returns: +// - error: Returns `errInvalidProtocolFeePct` if either `feeProtocol0` or `feeProtocol1` is invalid. +// Returns `nil` if both values are valid. func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { return errInvalidProtocolFeePct @@ -303,17 +360,10 @@ func CollectProtocolByAdmin( token1Path string, fee uint32, recipient std.Address, - amount0Requested string, // uint128 - amount1Requested string, // uint128 -) (string, string) { // uint128 x2 - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) - - caller := std.PrevRealm().Addr() - if err := common.AdminOnly(caller); err != nil { - panic(err) - } - + amount0Requested string, + amount1Requested string, +) (string, string) { + assertOnlyAdmin() amount0, amount1 := collectProtocol( token0Path, token1Path, @@ -323,11 +373,11 @@ func CollectProtocolByAdmin( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectProtocolByAdmin", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), @@ -351,14 +401,7 @@ func CollectProtocol( amount0Requested string, // uint128 amount1Requested string, // uint128 ) (string, string) { // uint128 x2 - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) - - caller := std.PrevRealm().Addr() - if err := common.GovernanceOnly(caller); err != nil { - panic(err) - } - + assertOnlyGovernance() amount0, amount1 := collectProtocol( token0Path, token1Path, @@ -368,11 +411,11 @@ func CollectProtocol( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectProtocol", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), @@ -384,6 +427,23 @@ func CollectProtocol( return amount0, amount1 } +// collectProtocol collects protocol fees for token0 and token1 from the specified pool. +// +// This function allows the collection of accumulated protocol fees for token0 and token1. It ensures +// the requested amounts do not exceed the available protocol fees in the pool and transfers the +// collected amounts to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token0. +// - token1Path: string, the path or identifier for token1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected protocol fees. +// - amount0Requested: string, the requested amount of token0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token0 collected, as a string. +// - string: The actual amount of token1 collected, as a string. func collectProtocol( token0Path string, token1Path string, @@ -392,25 +452,37 @@ func collectProtocol( amount0Requested string, amount1Requested string, ) (string, string) { - common.IsHalted() + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) + assertOnlyNotHalted() pool := GetPool(token0Path, token1Path, fee) amount0Req := u256.MustFromDecimal(amount0Requested) amount1Req := u256.MustFromDecimal(amount1Requested) - amount0 := u256Min(amount0Req, pool.protocolFees.token0) - amount1 := u256Min(amount1Req, pool.protocolFees.token1) + amount0 := u256Min(amount0Req, pool.ProtocolFeesToken0()) + amount1 := u256Min(amount1Req, pool.ProtocolFeesToken1()) - amount0, amount1 = pool.saveProtocolFees(amount0, amount1) + amount0, amount1 = pool.saveProtocolFees(amount0.Clone(), amount1.Clone()) uAmount0 := amount0.Uint64() uAmount1 := amount1.Uint64() token0Teller := common.GetTokenTeller(pool.token0Path) checkTransferError(token0Teller.Transfer(recipient, uAmount0)) + newBalanceToken0, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount0, true) + if err != nil { + panic(err) + } + pool.balances.token0 = newBalanceToken0 token1Teller := common.GetTokenTeller(pool.token1Path) checkTransferError(token1Teller.Transfer(recipient, uAmount1)) + newBalanceToken1, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount1, false) + if err != nil { + panic(err) + } + pool.balances.token1 = newBalanceToken1 return amount0.ToString(), amount1.ToString() } @@ -424,19 +496,19 @@ func collectProtocol( // Returns the adjusted amounts that will actually be collected for both tokens. func (p *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { cond01 := amount0.Gt(u256.Zero()) - cond02 := amount0.Eq(p.protocolFees.token0) + cond02 := amount0.Eq(p.ProtocolFeesToken0()) if cond01 && cond02 { amount0 = new(u256.Uint).Sub(amount0, u256.One()) } cond11 := amount1.Gt(u256.Zero()) - cond12 := amount1.Eq(p.protocolFees.token1) + cond12 := amount1.Eq(p.ProtocolFeesToken1()) if cond11 && cond12 { amount1 = new(u256.Uint).Sub(amount1, u256.One()) } - p.protocolFees.token0 = new(u256.Uint).Sub(p.protocolFees.token0, amount0) - p.protocolFees.token1 = new(u256.Uint).Sub(p.protocolFees.token1, amount1) + p.protocolFees.token0 = new(u256.Uint).Sub(p.ProtocolFeesToken0(), amount0) + p.protocolFees.token1 = new(u256.Uint).Sub(p.ProtocolFeesToken1(), amount1) // return rest fee return amount0, amount1 diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index c8f5c25be..575b95594 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -6,7 +6,6 @@ import ( "gno.land/p/demo/ufmt" - "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" en "gno.land/r/gnoswap/v1/emission" @@ -130,13 +129,10 @@ func CreatePool( token0Path string, token1Path string, fee uint32, - _sqrtPriceX96 string, + sqrtPriceX96 string, ) { - common.IsHalted() - en.MintAndDistributeGns() - - poolInfo := newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) - + assertOnlyNotHalted() + poolInfo := newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) if poolInfo.isSameTokenPath() { panic(addDetailToError( errDuplicateTokenInPool, @@ -146,14 +142,14 @@ func CreatePool( ), )) } + en.MintAndDistributeGns() // wrap first token0Path, token1Path = poolInfo.wrap() - poolPath := GetPoolPath(token0Path, token1Path, fee) // reinitialize poolInfo with wrapped tokens - poolInfo = newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) + poolInfo = newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) // then check if token0Path == token1Path if poolInfo.isSameTokenPath() { @@ -173,8 +169,7 @@ func CreatePool( )) } - // TODO: make this as a parameter - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() // check whether the pool already exist pool, exist := pools.Get(poolPath) @@ -208,7 +203,7 @@ func CreatePool( "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), - "sqrtPriceX96", _sqrtPriceX96, + "sqrtPriceX96", sqrtPriceX96, "internal_poolPath", poolPath, ) } @@ -220,46 +215,91 @@ func DoesPoolPathExist(poolPath string) bool { return exist } -// GetPool retrieves the pool for the given token paths and fee. -// It constructs the poolPath from the given parameters and returns the corresponding pool. -// Returns pool struct +// GetPool retrieves a pool instance based on the provided token paths and fee tier. +// +// This function determines the pool path by combining the paths of token0 and token1 along with the fee tier, +// and then retrieves the corresponding pool instance using that path. +// +// Parameters: +// - token0Path (string): The unique path for token0. +// - token1Path (string): The unique path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the provided tokens and fee tier. +// +// Notes: +// - The order of token paths (token0Path and token1Path) matters and should match the pool's configuration. +// - Ensure that the tokens and fee tier provided are valid and registered in the system. +// +// Example: +// pool := GetPool("gno.land/r/demo/wugnot", "gno.land/r/gnoswap/v1/gns", 3000) func GetPool(token0Path, token1Path string, fee uint32) *Pool { poolPath := GetPoolPath(token0Path, token1Path, fee) - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPool() || expected poolPath(%s) to exist", poolPath), - )) - } - - return pool + return GetPoolFromPoolPath(poolPath) } -// GetPoolFromPoolPath retrieves the pool for the given poolPath. +// GetPoolFromPoolPath retrieves a pool instance based on the provided pool path. +// +// This function checks if a pool exists for the given poolPath in the `pools` mapping. +// If the pool exists, it returns the pool instance. Otherwise, it panics with a descriptive error. +// +// Parameters: +// - poolPath (string): The unique identifier or path for the pool. +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the given poolPath. +// +// Panics: +// - If the `poolPath` does not exist in the `pools` mapping, it panics with an error message +// indicating that the expected poolPath was not found. +// +// Notes: +// - Ensure that the `poolPath` provided is valid and corresponds to an existing pool in the `pools` mapping. +// +// Example: +// pool := GetPoolFromPoolPath("path/to/pool") func GetPoolFromPoolPath(poolPath string) *Pool { pool, exist := pools[poolPath] if !exist { panic(addDetailToError( errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPoolFromPoolPath() || expected poolPath(%s) to exist", poolPath), + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), )) } - return pool } -// GetPoolPath generates a poolPath from the given token paths and fee. -// The poolPath is constructed by joining the token paths and fee with colons. +// GetPoolPath generates a unique pool path string based on the token paths and fee tier. +// +// This function ensures that the token paths are registered and sorted in alphabetical order +// before combining them with the fee tier to create a unique identifier for the pool. +// +// Parameters: +// - token0Path (string): The unique identifier or path for token0. +// - token1Path (string): The unique identifier or path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - string: A unique pool path string in the format "token0Path:token1Path:fee". +// +// Notes: +// - The function validates that both `token0Path` and `token1Path` are registered in the system +// using `common.MustRegistered`. +// - The token paths are sorted alphabetically to ensure consistent pool path generation, regardless +// of the input order. +// - This sorting guarantees that the pool path remains deterministic for the same pair of tokens and fee. +// +// Example: +// poolPath := GetPoolPath("path/to/token0", "path/to/token1", 3000) +// // Output: "path/to/token0:path/to/token1:3000" func GetPoolPath(token0Path, token1Path string, fee uint32) string { - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) - // TODO: this check is not unnecessary, if we are sure that // all the token paths in the pool are sorted in alphabetical order. if strings.Compare(token1Path, token0Path) < 0 { token0Path, token1Path = token1Path, token0Path } - return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) } diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 328315b56..36637a281 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -202,3 +202,59 @@ func TestCreatePool(t *testing.T) { resetObject(t) } + +func TestGetPool(t *testing.T) { + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "Panic - unregisterd poolPath ", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + action: func(t *testing.T) { + GetPool(barPath, fooPath, fee500) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-008] requested data not found || expected poolPath(gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500) to exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + } + }) + } +} diff --git a/pool/pool_test.gno b/pool/pool_test.gno index b58d95ac3..7a7050321 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -1,13 +1,10 @@ package pool import ( - "std" "testing" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" - - i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" "gno.land/r/gnoswap/v1/consts" @@ -100,7 +97,7 @@ func TestBurn(t *testing.T) { } // setup position for this test - posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) + posKey := getPositionKey(mockCaller, tt.tickLower, tt.tickUpper) mockPool.positions[posKey] = mockPosition if tt.expectPanic { @@ -140,3 +137,28 @@ func TestBurn(t *testing.T) { } } +func TestSetFeeProtocolInternal(t *testing.T) { + tests := []struct { + name string + feeProtocol0 uint8 + feeProtocol1 uint8 + eventName string + }{ + { + name: "set fee protocol by admin", + feeProtocol0: 4, + feeProtocol1: 5, + eventName: "SetFeeProtocolByAdmin", + }, + } + + for _, tt := range tests { + t.Run("set fee protocol by admin", func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + SetFeeProtocolByAdmin(tt.feeProtocol0, tt.feeProtocol1) + uassert.Equal(t, tt.feeProtocol0, pool.Slot0FeeProtocol()%16) + uassert.Equal(t, tt.feeProtocol1, pool.Slot0FeeProtocol()>>4) + }) + } +} diff --git a/pool/pool_transfer.gno b/pool/pool_transfer.gno index 201ae2cfb..da3951f01 100644 --- a/pool/pool_transfer.gno +++ b/pool/pool_transfer.gno @@ -11,7 +11,7 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -// transferAndVerify performs a token transfer out of the pool while ensuring +// safeTransfer performs a token transfer out of the pool while ensuring // the pool has sufficient balance and updating internal accounting. // This function is typically used during swaps and liquidity removals. // @@ -33,7 +33,7 @@ import ( // 4. Update pool's internal balance // // Panics if any validation fails or if the transfer fails -func (p *Pool) transferAndVerify( +func (p *Pool) safeTransfer( to std.Address, tokenPath string, amount *i256.Int, @@ -47,16 +47,13 @@ func (p *Pool) transferAndVerify( absAmount := amount.Abs() - token0 := p.balances.token0 - token1 := p.balances.token1 + token0 := p.BalanceToken0() + token1 := p.BalanceToken1() if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { panic(err) } - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(absAmount) token := common.GetTokenTeller(tokenPath) checkTransferError(token.Transfer(to, amountUint64)) @@ -73,45 +70,79 @@ func (p *Pool) transferAndVerify( } } -// transferFromAndVerify performs a token transfer into the pool using transferFrom -// while updating the pool's internal accounting. This function is typically used -// during swaps and liquidity additions. +// safeTransferFrom securely transfers tokens into the pool while ensuring balance consistency. // -// The function assumes the sender has approved the pool to spend their tokens. +// This function performs the following steps: +// 1. Validates and converts the transfer amount to `uint64` using `safeConvertToUint64`. +// 2. Executes the token transfer using `TransferFrom` via the token teller contract. +// 3. Verifies that the destination balance reflects the correct amount after transfer. +// 4. Updates the pool's internal balances (`token0` or `token1`) and validates the updated state. // // Parameters: -// - from: source address for the transfer -// - to: destination address (typically the pool) -// - tokenPath: path identifier of the token to transfer -// - amount: amount to transfer (must be positive) -// - isToken0: true if transferring token0, false for token1 +// - from (std.Address): Source address for the token transfer. +// - to (std.Address): Destination address, typically the pool address. +// - tokenPath (string): Path identifier for the token being transferred. +// - amount (*u256.Uint): The amount of tokens to transfer (must be a positive value). +// - isToken0 (bool): A flag indicating whether the token being transferred is token0 (`true`) or token1 (`false`). // -// The function will: -// 1. Convert amount to uint64 (must fit) -// 2. Execute the transferFrom -// 3. Update pool's internal balance +// Panics: +// - If the `amount` exceeds the uint64 range during conversion. +// - If the token transfer (`TransferFrom`) fails. +// - If the destination balance after the transfer does not match the expected amount. +// - If the pool's internal balances (`token0` or `token1`) overflow or become inconsistent. // -// Panics if the amount conversion fails or if the transfer fails -func (p *Pool) transferFromAndVerify( +// Notes: +// - The function assumes that the sender (`from`) has approved the pool to spend the specified tokens. +// - The balance consistency check ensures that no tokens are lost or double-counted during the transfer. +// - Pool balance updates are performed atomically to ensure internal consistency. +// +// Example: +// p.safeTransferFrom( +// +// sender, poolAddress, "path/to/token0", u256.MustFromDecimal("1000"), true +// +// ) +func (p *Pool) safeTransferFrom( from, to std.Address, tokenPath string, amount *u256.Uint, isToken0 bool, ) { - absAmount := amount - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(amount) - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.TransferFrom(from, to, amountUint64)) + token := common.GetToken(tokenPath) + beforeBalance := token.BalanceOf(to) + + teller := common.GetTokenTeller(tokenPath) + checkTransferError(teller.TransferFrom(from, to, amountUint64)) + + afterBalance := token.BalanceOf(to) + if (beforeBalance + amountUint64) != afterBalance { + panic(ufmt.Sprintf( + "%v. beforeBalance(%d) + amount(%d) != afterBalance(%d)", + errTransferFailed, beforeBalance, amountUint64, afterBalance, + )) + } // update pool balances if isToken0 { - p.balances.token0 = new(u256.Uint).Add(p.balances.token0, absAmount) + beforeToken0 := p.balances.token0.Clone() + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, amount) + if p.balances.token0.Lt(beforeToken0) { + panic(ufmt.Sprintf( + "%v. token0(%s) < beforeToken0(%s)", + errBalanceUpdateFailed, p.balances.token0.ToString(), beforeToken0.ToString(), + )) + } } else { - p.balances.token1 = new(u256.Uint).Add(p.balances.token1, absAmount) + beforeToken1 := p.balances.token1.Clone() + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, amount) + if p.balances.token1.Lt(beforeToken1) { + panic(ufmt.Sprintf( + "%v. token1(%s) < beforeToken1(%s)", + errBalanceUpdateFailed, p.balances.token1.ToString(), beforeToken1.ToString(), + )) + } } } @@ -150,7 +181,7 @@ func updatePoolBalance( if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%v. cannot decrease, token0(%s) - amount(%s)", - errTransferFailed, token0.ToString(), amount.ToString(), + errBalanceUpdateFailed, token0.ToString(), amount.ToString(), ) } return newBalance, nil @@ -160,12 +191,13 @@ func updatePoolBalance( if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%v. cannot decrease, token1(%s) - amount(%s)", - errTransferFailed, token1.ToString(), amount.ToString(), + errBalanceUpdateFailed, token1.ToString(), amount.ToString(), ) } return newBalance, nil } +// isBalanceOverflowOrNegative checks if the balance calculation resulted in an overflow or negative value. func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { return overflow || newBalance.Lt(u256.Zero()) } diff --git a/pool/pool_transfer_test.gno b/pool/pool_transfer_test.gno index 562825b9a..9a610d51d 100644 --- a/pool/pool_transfer_test.gno +++ b/pool/pool_transfer_test.gno @@ -134,10 +134,10 @@ func TestTransferFromAndVerify(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - TokenFaucet(t, fooPath, pusers.AddressOrName(tt.from)) - TokenApprove(t, fooPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + TokenFaucet(t, tt.tokenPath, pusers.AddressOrName(tt.from)) + TokenApprove(t, tt.tokenPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) - tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + tt.pool.safeTransferFrom(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) if !tt.pool.balances.token0.Eq(tt.expectedBal0) { t.Errorf("token0 balance mismatch: expected %s, got %s", @@ -165,7 +165,7 @@ func TestTransferFromAndVerify(t *testing.T) { TokenFaucet(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr"))) TokenApprove(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr")), pusers.AddressOrName(consts.POOL_ADDR), u256.MustFromDecimal(negativeAmount.Abs().ToString()).Uint64()) - pool.transferFromAndVerify( + pool.safeTransferFrom( testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), fooPath, @@ -197,7 +197,7 @@ func TestTransferFromAndVerify(t *testing.T) { } }() - pool.transferFromAndVerify( + pool.safeTransferFrom( testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), fooPath, diff --git a/pool/position.gno b/pool/position.gno index 64d37e9d5..8afba9905 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -15,49 +15,52 @@ var ( Q128 = u256.MustFromDecimal(consts.Q128) ) -// positionGetKey generates a unique key for a position based on the owner's address and the tick range. -func positionGetKey( +// getPositionKey generates a unique, encoded key for a liquidity position. +// +// This function creates a unique key for identifying a liquidity position in a pool. The key is based +// on the position's owner address, lower tick, and upper tick values. The generated key is then encoded +// as a base64 string to ensure compatibility and uniqueness. +// +// Parameters: +// - owner: std.Address, the address of the position's owner. +// - tickLower: int32, the lower tick boundary for the position. +// - tickUpper: int32, the upper tick boundary for the position. +// +// Returns: +// - string: A base64-encoded string representing the unique position key. +// +// Workflow: +// 1. Validates that the `owner` address is valid using `assertOnlyValidAddress`. +// 2. Ensures `tickLower` is less than `tickUpper` using `assertTickLowerLessThanUpper`. +// 3. Constructs the position key as a formatted string: +// "____" +// 4. Encodes the generated position key into a base64 string for safety and uniqueness. +// 5. Returns the encoded position key. +// +// Example: +// +// owner := std.Address("0x123456789") +// positionKey := getPositionKey(owner, 100, 200) +// fmt.Println("Position Key:", positionKey) +// // Output: base64-encoded string representing "0x123456789__100__200" +// +// Notes: +// - The base64 encoding ensures that the position key can be safely used as an identifier +// across different systems or data stores. +// - The function will panic if: +// - The `owner` address is invalid. +// - `tickLower` is greater than or equal to `tickUpper`. +func getPositionKey( owner std.Address, tickLower int32, tickUpper int32, ) string { - if !owner.IsValid() { - panic(addDetailToError( - errInvalidAddress, - ufmt.Sprintf("position.gno__positionGetKey() || invalid owner address %s", owner.String()), - )) - } - - if tickLower > tickUpper { - panic(addDetailToError( - errInvalidTickRange, - ufmt.Sprintf("position.gno__positionGetKey() || tickLower(%d) is greater than tickUpper(%d)", tickLower, tickUpper), - )) - } + assertOnlyValidAddress(owner) + assertTickLowerLessThanUpper(tickLower, tickUpper) positionKey := ufmt.Sprintf("%s__%d__%d", owner.String(), tickLower, tickUpper) - - encoded := base64.StdEncoding.EncodeToString([]byte(positionKey)) - return encoded -} - -// positionUpdateWithKey updates a position in the pool and returns the updated position. -func (pool *Pool) positionUpdateWithKey( - positionKey string, - liquidityDelta *i256.Int, - feeGrowthInside0X128 *u256.Uint, - feeGrowthInside1X128 *u256.Uint, -) PositionInfo { - // if pointer is nil, set to zero for calculation - liquidityDelta = liquidityDelta.NilToZero() - feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() - feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() - - positionToUpdate := pool.positions[positionKey] - positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) - pool.positions[positionKey] = positionAfterUpdate - - return positionAfterUpdate + encodedPositionKey := base64.StdEncoding.EncodeToString([]byte(positionKey)) + return encodedPositionKey } // positionUpdate calculates and returns an updated PositionInfo. @@ -67,7 +70,7 @@ func positionUpdate( feeGrowthInside0X128 *u256.Uint, feeGrowthInside1X128 *u256.Uint, ) PositionInfo { - position.init() + position.valueOrZero() var liquidityNext *u256.Uint if liquidityDelta.IsZero() { @@ -105,25 +108,48 @@ func positionUpdate( return position } -// receiver getters +// positionUpdateWithKey updates a position in the pool and returns the updated position. +func (p *Pool) positionUpdateWithKey( + positionKey string, + liquidityDelta *i256.Int, + feeGrowthInside0X128 *u256.Uint, + feeGrowthInside1X128 *u256.Uint, +) PositionInfo { + // if pointer is nil, set to zero for calculation + liquidityDelta = liquidityDelta.NilToZero() + feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() + feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() + + positionToUpdate := p.GetPosition(positionKey) + positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) + + p.positions[positionKey] = positionAfterUpdate + + return positionAfterUpdate +} -func (p *Pool) GetPositionLiquidity(key string) *u256.Uint { +// PositionLiquidity returns the liquidity of a position. +func (p *Pool) PositionLiquidity(key string) *u256.Uint { return p.mustGetPosition(key).liquidity } -func (p *Pool) GetPositionFeeGrowthInside0LastX128(key string) *u256.Uint { +// PositionFeeGrowthInside0LastX128 returns the fee growth of token0 inside a position. +func (p *Pool) PositionFeeGrowthInside0LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside0LastX128 } -func (p *Pool) GetPositionFeeGrowthInside1LastX128(key string) *u256.Uint { +// PositionFeeGrowthInside1LastX128 returns the fee growth of token1 inside a position. +func (p *Pool) PositionFeeGrowthInside1LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside1LastX128 } -func (p *Pool) GetPositionTokensOwed0(key string) *u256.Uint { +// PositionTokensOwed0 returns the amount of token0 owed by a position. +func (p *Pool) PositionTokensOwed0(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed0 } -func (p *Pool) GetPositionTokensOwed1(key string) *u256.Uint { +// PositionTokensOwed1 returns the amount of token1 owed by a position. +func (p *Pool) PositionTokensOwed1(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed1 } @@ -135,6 +161,15 @@ func (p *Pool) mustGetPosition(key string) PositionInfo { ufmt.Sprintf("position(%s) does not exist", key), )) } + return position +} +func (p *Pool) GetPosition(key string) PositionInfo { + position, exist := p.positions[key] + if !exist { + newPosition := PositionInfo{} + newPosition.valueOrZero() + return newPosition + } return position } diff --git a/pool/position_modify.gno b/pool/position_modify.gno index c6c574005..fe3bf876f 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -1,11 +1,12 @@ package pool import ( - "gno.land/r/gnoswap/v1/common" - + "gno.land/p/demo/ufmt" i256 "gno.land/p/gnoswap/int256" plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) // modifyPosition updates a position in the pool and calculates the amount of tokens @@ -25,46 +26,49 @@ import ( // - *u256.Uint: amount of token0 needed/returned // - *u256.Uint: amount of token1 needed/returned func (p *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { + checkTicks(params.tickLower, params.tickUpper) + + // get current state and price bounds + tick := p.Slot0Tick() // update position state - position := p.updatePosition(params) + position := p.updatePosition(params, tick) liqDelta := params.liquidityDelta - if liqDelta.IsZero() { return position, u256.Zero(), u256.Zero() } amount0, amount1 := i256.Zero(), i256.Zero() - // get current state and price bounds - tick := p.slot0.tick // covert ticks to sqrt price to use in amount calculations // price = 1.0001^tick, but we use sqrtPriceX96 sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) - sqrtPriceX96 := p.slot0.sqrtPriceX96 + sqrtPriceX96 := p.Slot0SqrtPriceX96() // calculate token amounts based on current price position relative to range switch { case tick < params.tickLower: // case 1 // full range between lower and upper tick is used for token0 + // current tick is below the passed range; liquidity can only become in range by crossing from left to + // right, when we'll need _more_ token0 (it's becoming more valuable) so user must provide it amount0 = calculateToken0Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) case tick < params.tickUpper: // case 2 liquidityBefore := p.liquidity - // token0 used from current price to upper tick amount0 = calculateToken0Amount(sqrtPriceX96, sqrtRatioUpper, liqDelta) // token1 used from lower tick to current price amount1 = calculateToken1Amount(sqrtRatioLower, sqrtPriceX96, liqDelta) - // update pool's active liquidity since price is in range p.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) default: // case 3 // full range between lower and upper tick is used for token1 + // current tick is above the passed range; liquidity can only become in range by crossing from right to + // left, when we'll need _more_ token1 (it's becoming more valuable) so user must provide it amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } @@ -80,3 +84,36 @@ func calculateToken1Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityD res := plp.SqrtPriceMathGetAmount1DeltaStr(sqrtPriceLower, sqrtPriceUpper, liquidityDelta) return i256.MustFromDecimal(res) } + +func checkTicks(tickLower, tickUpper int32) { + assertTickLowerLessThanUpper(tickLower, tickUpper) + assertValidTickLower(tickLower) + assertValidTickUpper(tickUpper) +} + +func assertTickLowerLessThanUpper(tickLower, tickUpper int32) { + if tickLower >= tickUpper { + panic(addDetailToError( + errInvalidTickRange, + ufmt.Sprintf("tickLower(%d), tickUpper(%d)", tickLower, tickUpper), + )) + } +} + +func assertValidTickLower(tickLower int32) { + if tickLower < consts.MIN_TICK { + panic(addDetailToError( + errTickLowerInvalid, + ufmt.Sprintf("tickLower(%d) < MIN_TICK(%d)", tickLower, consts.MIN_TICK), + )) + } +} + +func assertValidTickUpper(tickUpper int32) { + if tickUpper > consts.MAX_TICK { + panic(addDetailToError( + errTickUpperInvalid, + ufmt.Sprintf("tickUpper(%d) > MAX_TICK(%d)", tickUpper, consts.MAX_TICK), + )) + } +} diff --git a/pool/position_modify_test.gno b/pool/position_modify_test.gno index a6c2da6d0..4a8ed1401 100644 --- a/pool/position_modify_test.gno +++ b/pool/position_modify_test.gno @@ -134,3 +134,159 @@ func TestModifyPositionEdgeCases(t *testing.T) { uassert.Equal(t, amount1.ToString(), "2958014") }) } + +func TestAssertTickLowerLessThanUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower is less than tickUpper", + tickLower: -100, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower equals tickUpper", + tickLower: 50, + tickUpper: 50, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(50), tickUpper(50)", + }, + { + name: "tickLower greater than tickUpper", + tickLower: 200, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertTickLowerLessThanUpper(tt.tickLower, tt.tickUpper) + }) + } +} + +func TestAssertValidTickLower(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower equals MIN_TICK", + tickLower: consts.MIN_TICK, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower greater than MIN_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: 50, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower less than MIN_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-028] tickLower is invalid || tickLower(-887273) < MIN_TICK(-887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickLower(tt.tickLower) + }) + } +} + +func TestAssertValidTickUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickUpper equals MAX_TICK", + tickLower: consts.MIN_TICK, + tickUpper: consts.MAX_TICK, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper less than MAX_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: consts.MAX_TICK - 1, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper greater than MAX_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: consts.MAX_TICK + 1, + shouldPanic: true, + expected: "[GNOSWAP-POOL-029] tickUpper is invalid || tickUpper(887273) > MAX_TICK(887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickUpper(tt.tickUpper) + }) + } +} diff --git a/pool/position_test.gno b/pool/position_test.gno index 46928b6a8..b96821af7 100644 --- a/pool/position_test.gno +++ b/pool/position_test.gno @@ -25,18 +25,23 @@ func TestPositionGetKey(t *testing.T) { panicMsg string expectedKey string }{ - {invalidAddr, 100, 200, true, `[GNOSWAP-POOL-023] invalid address || position.gno__positionGetKey() || invalid owner address invalidAddr`, ""}, // invalid address - {validAddr, 200, 100, true, `[GNOSWAP-POOL-024] tickLower is greater than tickUpper || position.gno__positionGetKey() || tickLower(200) is greater than tickUpper(100)`, ""}, // tickLower > tickUpper - {validAddr, -100, -200, true, `[GNOSWAP-POOL-024] tickLower is greater than tickUpper || position.gno__positionGetKey() || tickLower(-100) is greater than tickUpper(-200)`, ""}, // tickLower > tickUpper - {validAddr, 100, 100, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18xMDA="}, // tickLower == tickUpper - {validAddr, 100, 200, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18yMDA="}, // tickLower < tickUpper + {invalidAddr, 100, 200, true, `[GNOSWAP-POOL-023] invalid address || (invalidAddr)`, ""}, // invalid address + {validAddr, 200, 100, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)`, ""}, // tickLower > tickUpper + {validAddr, -100, -200, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(-100), tickUpper(-200)`, ""}, // tickLower > tickUpper + {validAddr, 100, 100, true, "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(100), tickUpper(100)", ""}, // tickLower == tickUpper + {validAddr, 100, 200, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18yMDA="}, // tickLower < tickUpper } for _, tc := range tests { + defer func() { + if r := recover(); r != nil { + uassert.Equal(t, tc.panicMsg, r.(string)) + } + }() if tc.shouldPanic { - uassert.PanicsWithMessage(t, tc.panicMsg, func() { positionGetKey(tc.owner, tc.tickLower, tc.tickUpper) }) + uassert.PanicsWithMessage(t, tc.panicMsg, func() { getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) }) } else { - key := positionGetKey(tc.owner, tc.tickLower, tc.tickUpper) + key := getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) uassert.Equal(t, tc.expectedKey, key) } } @@ -55,7 +60,7 @@ func TestPositionUpdateWithKey(t *testing.T) { ) dummyPool = newPool(poolParams) - positionKey = positionGetKey( + positionKey = getPositionKey( testutils.TestAddress("dummyAddr"), 100, 200, diff --git a/pool/position_update.gno b/pool/position_update.gno index 5f21e49f7..827d91f10 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -1,68 +1,97 @@ package pool -import ( - u256 "gno.land/p/gnoswap/uint256" -) - -func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionInfo { - feeGrowthGlobal0X128 := pool.feeGrowthGlobal0X128.Clone() - feeGrowthGlobal1X128 := pool.feeGrowthGlobal1X128.Clone() +// updatePosition modifies the position's liquidity and updates the corresponding tick states. +// +// This function updates the position data based on the specified liquidity delta and tick range. +// It also manages the fee growth, tick state flipping, and cleanup of unused tick data. +// +// Parameters: +// - positionParams: ModifyPositionParams, the parameters for the position modification, which include: +// - owner: The address of the position owner. +// - tickLower: The lower tick boundary of the position. +// - tickUpper: The upper tick boundary of the position. +// - liquidityDelta: The change in liquidity (positive or negative). +// - tick: int32, the current tick position. +// +// Returns: +// - PositionInfo: The updated position information. +// +// Workflow: +// 1. Clone the global fee growth values (token 0 and token 1). +// 2. If the liquidity delta is non-zero: +// - Update the lower and upper ticks using `tickUpdate`, flipping their states if necessary. +// - If a tick's state was flipped, update the tick bitmap to reflect the new state. +// 3. Calculate the fee growth inside the tick range using `getFeeGrowthInside`. +// 4. Generate a unique position key and update the position data using `positionUpdateWithKey`. +// 5. If liquidity is being removed (negative delta), clean up unused tick data by deleting the tick entries. +// 6. Return the updated position. +// +// Notes: +// - The function flips the tick states and cleans up unused tick data when liquidity is removed. +// - It ensures fee growth and position data remain accurate after the update. +// +// Example Usage: +// +// updatedPosition := pool.updatePosition(positionParams, currentTick) +// fmt.Println("Updated Position Info:", updatedPosition) +func (p *Pool) updatePosition(positionParams ModifyPositionParams, tick int32) PositionInfo { + feeGrowthGlobal0X128 := p.FeeGrowthGlobal0X128().Clone() + feeGrowthGlobal1X128 := p.FeeGrowthGlobal1X128().Clone() var flippedLower, flippedUpper bool if !(positionParams.liquidityDelta.IsZero()) { - flippedLower = pool.tickUpdate( + flippedLower = p.tickUpdate( positionParams.tickLower, - pool.slot0.tick, + tick, positionParams.liquidityDelta, feeGrowthGlobal0X128, feeGrowthGlobal1X128, false, - pool.maxLiquidityPerTick, + p.maxLiquidityPerTick, ) - flippedUpper = pool.tickUpdate( + flippedUpper = p.tickUpdate( positionParams.tickUpper, - pool.slot0.tick, + tick, positionParams.liquidityDelta, feeGrowthGlobal0X128, feeGrowthGlobal1X128, true, - pool.maxLiquidityPerTick, + p.maxLiquidityPerTick, ) if flippedLower { - pool.tickBitmapFlipTick(positionParams.tickLower, pool.tickSpacing) + p.tickBitmapFlipTick(positionParams.tickLower, p.tickSpacing) } if flippedUpper { - pool.tickBitmapFlipTick(positionParams.tickUpper, pool.tickSpacing) + p.tickBitmapFlipTick(positionParams.tickUpper, p.tickSpacing) } } - feeGrowthInside0X128, feeGrowthInside1X128 := pool.calculateFeeGrowthInside( + feeGrowthInside0X128, feeGrowthInside1X128 := p.getFeeGrowthInside( positionParams.tickLower, positionParams.tickUpper, - pool.slot0.tick, + tick, feeGrowthGlobal0X128, feeGrowthGlobal1X128, ) - positionKey := positionGetKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) - - position := pool.positionUpdateWithKey( + positionKey := getPositionKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) + position := p.positionUpdateWithKey( positionKey, positionParams.liquidityDelta, - u256.MustFromDecimal(feeGrowthInside0X128.ToString()), - u256.MustFromDecimal(feeGrowthInside1X128.ToString()), + feeGrowthInside0X128.Clone(), + feeGrowthInside1X128.Clone(), ) + // clear any tick data that is no longer needed if positionParams.liquidityDelta.IsNeg() { if flippedLower { - delete(pool.ticks, positionParams.tickLower) + delete(p.ticks, positionParams.tickLower) } - if flippedUpper { - delete(pool.ticks, positionParams.tickUpper) + delete(p.ticks, positionParams.tickUpper) } } diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno index 24ca4cb55..ad2e67077 100644 --- a/pool/position_update_test.gno +++ b/pool/position_update_test.gno @@ -3,19 +3,15 @@ package pool import ( "testing" - "std" - - "gno.land/p/demo/uassert" - - "gno.land/r/gnoswap/v1/consts" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/consts" ) func TestUpdatePosition(t *testing.T) { poolParams := &createPoolParams{ - token0Path: "token0", - token1Path: "token1", + token0Path: "token0", + token1Path: "token1", fee: 500, tickSpacing: 10, sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 @@ -23,8 +19,8 @@ func TestUpdatePosition(t *testing.T) { p := newPool(poolParams) tests := []struct { - name string - positionParams ModifyPositionParams + name string + positionParams ModifyPositionParams expectLiquidity *u256.Uint }{ { @@ -61,11 +57,12 @@ func TestUpdatePosition(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - position := p.updatePosition(tt.positionParams) - + tick := p.Slot0Tick() + position := p.updatePosition(tt.positionParams, tick) + if !position.liquidity.Eq(tt.expectLiquidity) { - t.Errorf("liquidity mismatch: expected %s, got %s", - tt.expectLiquidity.ToString(), + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), position.liquidity.ToString()) } diff --git a/pool/protocol_fee_pool_creation.gno b/pool/protocol_fee_pool_creation.gno index a56284298..c6e4e06ae 100644 --- a/pool/protocol_fee_pool_creation.gno +++ b/pool/protocol_fee_pool_creation.gno @@ -32,7 +32,7 @@ func SetPoolCreationFee(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFee", "prevAddr", prevAddr, @@ -54,7 +54,7 @@ func SetPoolCreationFeeByAdmin(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/protocol_fee_withdrawal.gno b/pool/protocol_fee_withdrawal.gno index cacc602c7..d7d208b9b 100644 --- a/pool/protocol_fee_withdrawal.gno +++ b/pool/protocol_fee_withdrawal.gno @@ -72,7 +72,7 @@ func HandleWithdrawalFee( token1Teller := common.GetTokenTeller(token1Path) checkTransferError(token1Teller.TransferFrom(positionCaller, consts.PROTOCOL_FEE_ADDR, feeAmount1.Uint64())) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "WithdrawalFee", "prevAddr", prevAddr, @@ -106,7 +106,7 @@ func SetWithdrawalFee(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFee", "prevAddr", prevAddr, @@ -126,7 +126,7 @@ func SetWithdrawalFeeByAdmin(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/swap.gno b/pool/swap.gno index 47c5e2d67..3b8b9c60f 100644 --- a/pool/swap.gno +++ b/pool/swap.gno @@ -107,7 +107,7 @@ func Swap( // actual swap pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Swap", @@ -512,13 +512,13 @@ func tickTransition(step StepComputations, zeroForOne bool, state SwapState, poo func (p *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { if zeroForOne { // payer > POOL - p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) // POOL > recipient - p.transferAndVerify(recipient, p.token1Path, amount1, false) + p.safeTransfer(recipient, p.token1Path, amount1, false) } else { // payer > POOL - p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) // POOL > recipient - p.transferAndVerify(recipient, p.token0Path, amount0, true) + p.safeTransfer(recipient, p.token0Path, amount0, true) } } diff --git a/pool/tests/__TEST_pool_burn_test.gnoA b/pool/tests/__TEST_pool_burn_test.gnoA index bcaa501be..64d597353 100644 --- a/pool/tests/__TEST_pool_burn_test.gnoA +++ b/pool/tests/__TEST_pool_burn_test.gnoA @@ -81,7 +81,7 @@ func TestDoesNotClear(t *testing.T) { uassert.Equal(t, liq.ToString(), "0") // tokensOwed - thisPositionKey := positionGetKey(consts.POSITION_ADDR, -887160, 887160) + thisPositionKey := getPositionKey(consts.POSITION_ADDR, -887160, 887160) thisPosition := thisPool.positions[thisPositionKey] tokensOwed0 := thisPosition.tokensOwed0 diff --git a/pool/tests/__TEST_pool_spec_#6_test.gnoA b/pool/tests/__TEST_pool_spec_#6_test.gnoA index 03aed8d62..1e36d3c2f 100644 --- a/pool/tests/__TEST_pool_spec_#6_test.gnoA +++ b/pool/tests/__TEST_pool_spec_#6_test.gnoA @@ -133,7 +133,7 @@ func TestWorkAccross(t *testing.T) { ) // tokensOwed - thisPositionKey := positionGetKey(consts.POSITION_ADDR, -887270, 887270) + thisPositionKey := getPositionKey(consts.POSITION_ADDR, -887270, 887270) thisPosition := thisPool.positions[thisPositionKey] tokensOwed0 := thisPosition.tokensOwed0 diff --git a/pool/tick.gno b/pool/tick.gno index d0db635c2..63e070222 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -19,6 +19,42 @@ func calculateMaxLiquidityPerTick(tickSpacing int32) *u256.Uint { return new(u256.Uint).Div(u256.MustFromDecimal(consts.MAX_UINT128), u256.NewUint(numTicks)) } +// getFeeGrowthBelowX128 calculates the fee growth below a specified tick. +// +// This function computes the fee growth for token 0 and token 1 below a given tick (`tickLower`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickLower`. +// +// Parameters: +// - tickLower: int32, the lower tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - lowerTick: TickInfo, the fee growth and liquidity details for the lower tick. +// +// Returns: +// - *u256.Uint: Fee growth below `tickLower` for token 0. +// - *u256.Uint: Fee growth below `tickLower` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is greater than or equal to `tickLower`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `lowerTick`. +// 2. If `tickCurrent` is below `tickLower`: +// - Compute the fee growth below the lower tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent >= tickLower`, the fee growth outside the lower tick is returned as-is. +// - If `tickCurrent < tickLower`, the fee growth is calculated as: +// feeGrowthBelow = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( +// 100, 150, globalFeeGrowth0, globalFeeGrowth1, lowerTickInfo, +// ) +// fmt.Println("Fee Growth Below:", feeGrowth0, feeGrowth1) func getFeeGrowthBelowX128( tickLower, tickCurrent int32, feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, @@ -28,12 +64,48 @@ func getFeeGrowthBelowX128( return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 } - below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) - below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) + feeGrowthBelow0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) + feeGrowthBelow1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) - return below0X128, below1X128 + return feeGrowthBelow0X128, feeGrowthBelow1X128 } +// getFeeGrowthAboveX128 calculates the fee growth above a specified tick. +// +// This function computes the fee growth for token 0 and token 1 above a given tick (`tickUpper`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickUpper`. +// +// Parameters: +// - tickUpper: int32, the upper tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - upperTick: TickInfo, the fee growth and liquidity details for the upper tick. +// +// Returns: +// - *u256.Uint: Fee growth above `tickUpper` for token 0. +// - *u256.Uint: Fee growth above `tickUpper` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is less than `tickUpper`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `upperTick`. +// 2. If `tickCurrent` is greater than or equal to `tickUpper`: +// - Compute the fee growth above the upper tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent < tickUpper`, the fee growth outside the upper tick is returned as-is. +// - If `tickCurrent >= tickUpper`, the fee growth is calculated as: +// feeGrowthAbove = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( +// 200, 150, globalFeeGrowth0, globalFeeGrowth1, upperTickInfo, +// ) +// fmt.Println("Fee Growth Above:", feeGrowth0, feeGrowth1) func getFeeGrowthAboveX128( tickUpper, tickCurrent int32, feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, @@ -43,15 +115,51 @@ func getFeeGrowthAboveX128( return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 } - above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) - above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) + feeGrowthAbove0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) + feeGrowthAbove1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) - return above0X128, above1X128 + return feeGrowthAbove0X128, feeGrowthAbove1X128 } -// calculateFeeGrowthInside calculates the fee growth inside a tick range, -// and returns the fee growth inside for both tokens. -func (p *Pool) calculateFeeGrowthInside( +// getFeeGrowthInside calculates the fee growth within a specified tick range. +// +// This function computes the accumulated fee growth for token 0 and token 1 inside a given tick range +// (`tickLower` to `tickUpper`) relative to the current tick position (`tickCurrent`). It isolates the fee +// growth within the range by subtracting the fee growth below the lower tick and above the upper tick +// from the global fee growth. +// +// Parameters: +// - tickLower: int32, the lower tick boundary of the range. +// - tickUpper: int32, the upper tick boundary of the range. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// +// Returns: +// - *u256.Uint: Fee growth inside the tick range for token 0. +// - *u256.Uint: Fee growth inside the tick range for token 1. +// +// Workflow: +// 1. Retrieve the tick information (`lower` and `upper`) for the lower and upper tick boundaries +// using `p.getTick`. +// 2. Calculate the fee growth below the lower tick using `getFeeGrowthBelowX128`. +// 3. Calculate the fee growth above the upper tick using `getFeeGrowthAboveX128`. +// 4. Subtract the fee growth below and above the range from the global fee growth values: +// feeGrowthInside = feeGrowthGlobal - feeGrowthBelow - feeGrowthAbove +// 5. Return the computed fee growth values for token 0 and token 1 within the range. +// +// Behavior: +// - The fee growth is isolated within the range `[tickLower, tickUpper]`. +// - The function ensures the calculations accurately consider the tick boundaries and the current tick position. +// +// Example: +// +// feeGrowth0, feeGrowth1 := pool.getFeeGrowthInside( +// 100, 200, 150, globalFeeGrowth0, globalFeeGrowth1, +// ) +// fmt.Println("Fee Growth Inside (Token 0):", feeGrowth0) +// fmt.Println("Fee Growth Inside (Token 1):", feeGrowth1) +func (p *Pool) getFeeGrowthInside( tickLower int32, tickUpper int32, tickCurrent int32, @@ -70,11 +178,53 @@ func (p *Pool) calculateFeeGrowthInside( return feeGrowthInside0X128, feeGrowthInside1X128 } -// tickUpdate updates a tick's state and returns whether the tick was flipped. +// tickUpdate updates the state of a specific tick. +// +// This function applies a given liquidity change (liquidityDelta) to the specified tick, updates +// the fee growth values if necessary, and adjusts the net liquidity based on whether the tick +// is an upper or lower boundary. It also verifies that the total liquidity does not exceed the +// maximum allowed value and ensures the net liquidity stays within the valid int128 range. +// +// Parameters: +// - tick: int32, the index of the tick to update. +// - tickCurrent: int32, the current active tick index. +// - liquidityDelta: *i256.Int, the amount of liquidity to add or remove. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth value for token 0. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth value for token 1. +// - upper: bool, indicates if this is the upper boundary (true for upper, false for lower). +// - maxLiquidity: *u256.Uint, the maximum allowed liquidity. +// +// Returns: +// - flipped: bool, indicates if the tick's initialization state has changed. +// (e.g., liquidity transitioning from zero to non-zero, or vice versa) +// +// Workflow: +// 1. Nil input values are replaced with zero. +// 2. The function retrieves the tick information for the specified tick index. +// 3. Applies the liquidityDelta to compute the new total liquidity (liquidityGross). +// - If the total liquidity exceeds the maximum allowed value, the function panics. +// 4. Checks whether the tick's initialized state has changed and sets the `flipped` flag. +// 5. If the tick was previously uninitialized and its index is less than or equal to the current tick, +// the fee growth values are initialized to the current global values. +// 6. Updates the tick's net liquidity: +// - For an upper boundary, it subtracts liquidityDelta. +// - For a lower boundary, it adds liquidityDelta. +// - Ensures the net liquidity remains within the int128 range using `checkOverFlowInt128`. +// 7. Updates the tick's state with the new values. +// 8. Returns whether the tick's initialized state has flipped. +// +// Panic Conditions: +// - The total liquidity (liquidityGross) exceeds the maximum allowed liquidity (maxLiquidity). +// - The net liquidity (liquidityNet) exceeds the int128 range. +// +// Example: +// +// flipped := pool.tickUpdate(10, 5, liquidityDelta, feeGrowth0, feeGrowth1, true, maxLiquidity) +// fmt.Println("Tick flipped:", flipped) func (p *Pool) tickUpdate( tick int32, tickCurrent int32, - liquidityDelta *i256.Int, // int128 + liquidityDelta *i256.Int, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, upper bool, @@ -84,15 +234,15 @@ func (p *Pool) tickUpdate( feeGrowthGlobal0X128 = feeGrowthGlobal0X128.NilToZero() feeGrowthGlobal1X128 = feeGrowthGlobal1X128.NilToZero() - thisTick := p.getTick(tick) + tickInfo := p.getTick(tick) - liquidityGrossBefore := thisTick.liquidityGross + liquidityGrossBefore := tickInfo.liquidityGross.Clone() liquidityGrossAfter := liquidityMathAddDelta(liquidityGrossBefore, liquidityDelta) if !(liquidityGrossAfter.Lte(maxLiquidity)) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("tick.gno__tickUpdate() || liquidityGrossAfter(%s) overflows maxLiquidity(%s)", liquidityGrossAfter.ToString(), maxLiquidity.ToString()), + ufmt.Sprintf("liquidityGrossAfter(%s) overflows maxLiquidity(%s)", liquidityGrossAfter.ToString(), maxLiquidity.ToString()), )) } @@ -100,22 +250,23 @@ func (p *Pool) tickUpdate( if liquidityGrossBefore.IsZero() { if tick <= tickCurrent { - thisTick.feeGrowthOutside0X128 = feeGrowthGlobal0X128 - thisTick.feeGrowthOutside1X128 = feeGrowthGlobal1X128 + tickInfo.feeGrowthOutside0X128 = feeGrowthGlobal0X128.Clone() + tickInfo.feeGrowthOutside1X128 = feeGrowthGlobal1X128.Clone() } - - thisTick.initialized = true + tickInfo.initialized = true } - thisTick.liquidityGross = liquidityGrossAfter + tickInfo.liquidityGross = liquidityGrossAfter.Clone() if upper { - thisTick.liquidityNet = i256.Zero().Sub(thisTick.liquidityNet, liquidityDelta) + tickInfo.liquidityNet = i256.Zero().Sub(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) } else { - thisTick.liquidityNet = i256.Zero().Add(thisTick.liquidityNet, liquidityDelta) + tickInfo.liquidityNet = i256.Zero().Add(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) } - p.ticks[tick] = thisTick + p.setTick(tick, tickInfo) return flipped } @@ -136,47 +287,97 @@ func (p *Pool) tickCross( return thisTick.liquidityNet.Clone() } -// getTick returns a tick's state. +// setTick updates the tick data for the specified tick index in the pool. +func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { + p.ticks[tick] = newTickInfo +} + +// getTick retrieves the TickInfo associated with the specified tick index from the pool. +// If the TickInfo contains any nil fields, they are replaced with zero values using valueOrZero. +// +// Parameters: +// - tick: The tick index (int32) for which the TickInfo is to be retrieved. +// +// Behavior: +// - Retrieves the TickInfo for the given tick from the pool's tick map. +// - Ensures that all fields of TickInfo are non-nil by calling valueOrZero, which replaces nil values with zero. +// - Returns the updated TickInfo. +// +// Returns: +// - TickInfo: The tick data with all fields guaranteed to have valid values (nil fields are set to zero). +// +// Use Case: +// This function ensures the retrieved tick data is always valid and safe for further operations, +// such as calculations or updates, by sanitizing nil fields in the TickInfo structure. func (p *Pool) getTick(tick int32) TickInfo { tickInfo := p.ticks[tick] - tickInfo.init() - + tickInfo.valueOrZero() return tickInfo } -// receiver getters +// GetTickLiquidityGross returns the gross liquidity for the specified tick. func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { return p.mustGetTick(tick).liquidityGross } +// GetTickLiquidityNet returns the net liquidity for the specified tick. func (p *Pool) GetTickLiquidityNet(tick int32) *i256.Int { return p.mustGetTick(tick).liquidityNet } +// GetTickFeeGrowthOutside0X128 returns the fee growth outside the tick for token 0. func (p *Pool) GetTickFeeGrowthOutside0X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside0X128 } +// GetTickFeeGrowthOutside1X128 returns the fee growth outside the tick for token 1. func (p *Pool) GetTickFeeGrowthOutside1X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside1X128 } +// GetTickCumulativeOutside returns the cumulative liquidity outside the tick. func (p *Pool) GetTickCumulativeOutside(tick int32) int64 { return p.mustGetTick(tick).tickCumulativeOutside } +// GetTickSecondsPerLiquidityOutsideX128 returns the seconds per liquidity outside the tick. func (p *Pool) GetTickSecondsPerLiquidityOutsideX128(tick int32) *u256.Uint { return p.mustGetTick(tick).secondsPerLiquidityOutsideX128 } +// GetTickSecondsOutside returns the seconds outside the tick. func (p *Pool) GetTickSecondsOutside(tick int32) uint32 { return p.mustGetTick(tick).secondsOutside } +// GetTickInitialized returns whether the tick is initialized. func (p *Pool) GetTickInitialized(tick int32) bool { return p.mustGetTick(tick).initialized } +// mustGetTick retrieves the TickInfo for a specific tick, panicking if the tick does not exist. +// +// This function ensures that the requested tick data exists in the pool's tick mapping. +// If the tick does not exist, it panics with an appropriate error message. +// +// Parameters: +// - tick: int32, the index of the tick to retrieve. +// +// Returns: +// - TickInfo: The information associated with the specified tick. +// +// Behavior: +// - Checks if the tick exists in the pool's tick mapping (`p.ticks`). +// - If the tick exists, it returns the corresponding `TickInfo`. +// - If the tick does not exist, the function panics with a descriptive error. +// +// Panic Conditions: +// - The specified tick does not exist in the pool's mapping. +// +// Example: +// +// tickInfo := pool.mustGetTick(10) +// fmt.Println("Tick Info:", tickInfo) func (p *Pool) mustGetTick(tick int32) TickInfo { tickInfo, exist := p.ticks[tick] if !exist { diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 5c4dbf0c8..53c9cc416 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -1,8 +1,6 @@ package pool import ( - "gno.land/p/demo/ufmt" - plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" @@ -12,23 +10,45 @@ import ( func tickBitmapPosition(tick int32) (int16, uint8) { wordPos := int16(tick >> 8) // tick / 256 bitPos := uint8(tick % 256) - return wordPos, bitPos } -// tickBitmapFlipTick flips tthe bit corresponding to the given tick -// in the pool's tick bitmap. +// tickBitmapFlipTick flips the state of a tick in the tick bitmap. +// +// This function toggles the "initialized" state of a tick in the tick bitmap. +// It ensures that the tick aligns with the specified tick spacing and then +// flips the corresponding bit in the bitmap representation. +// +// Parameters: +// - tick: int32, the tick index to toggle. +// - tickSpacing: int32, the spacing between valid ticks. +// The tick must align with this spacing. +// +// Workflow: +// 1. Validates that the `tick` aligns with `tickSpacing` using `checkTickSpacing`. +// 2. Computes the position of the bit in the tick bitmap: +// - `wordPos`: Determines which word in the bitmap contains the bit. +// - `bitPos`: Identifies the position of the bit within the word. +// 3. Creates a bitmask using `Lsh` (Left Shift) to target the bit at `bitPos`. +// 4. Toggles (flips) the bit using XOR with the current value of the tick bitmap. +// 5. Updates the tick bitmap with the modified word. +// +// Behavior: +// - If the bit is `0` (uninitialized), it will be flipped to `1` (initialized). +// - If the bit is `1` (initialized), it will be flipped to `0` (uninitialized). +// +// Example: +// +// pool.tickBitmapFlipTick(120, 60) +// // This flips the bit for tick 120 with a tick spacing of 60. +// +// Notes: +// - The `tick` must be divisible by `tickSpacing`. If not, the function will panic. func (p *Pool) tickBitmapFlipTick( tick int32, tickSpacing int32, ) { - if tick%tickSpacing != 0 { - panic(addDetailToError( - errInvalidTickAndTickSpacing, - ufmt.Sprintf("tick_bitmap.gno__tickBitmapFlipTick() || tick(%d) MOD tickSpacing(%d) != 0(%d)", tick, tickSpacing, tick%tickSpacing), - )) - } - + checkTickSpacing(tick, tickSpacing) wordPos, bitPos := tickBitmapPosition(tick / tickSpacing) mask := new(u256.Uint).Lsh(u256.One(), uint(bitPos)) diff --git a/pool/tick_test.gno b/pool/tick_test.gno index 34fb95ab8..32fef15e7 100644 --- a/pool/tick_test.gno +++ b/pool/tick_test.gno @@ -130,8 +130,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "13", want1: "12", preconditions: func() { - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(2), u256.NewUint(3), @@ -154,9 +155,10 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "13", want1: "12", preconditions: func() { - pool.deleteTick(t, 2) // delete tick from previous test - pool.setTick( + deleteTick(t, pool, 2) // delete tick from previous test + setTick( t, + pool, -2, u256.NewUint(2), u256.NewUint(3), @@ -180,8 +182,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want1: "11", preconditions: func() { // we already have tick -2 - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(4), u256.NewUint(1), @@ -204,10 +207,11 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "16", want1: "13", preconditions: func() { - pool.deleteTick(t, 2) - pool.deleteTick(t, -2) - pool.setTick( + deleteTick(t, pool, 2) + deleteTick(t, pool, -2) + setTick( t, + pool, -2, u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639932"), // max uint256 - 3 u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639933"), // max uint256 - 2 @@ -218,8 +222,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { 0, true, ) - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(3), u256.NewUint(5), @@ -239,7 +244,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { if tt.preconditions != nil { tt.preconditions() } - got0, got1 := pool.calculateFeeGrowthInside( + got0, got1 := pool.getFeeGrowthInside( tt.tickLower, tt.tickUpper, tt.tickCurrent, @@ -247,7 +252,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { tt.feeGrowthGlobal1X128, ) if got0.ToString() != tt.want0 || got1.ToString() != tt.want1 { - t.Errorf("calculateFeeGrowthInside() = (%v, %v), want (%v, %v)", + t.Errorf("getFeeGrowthInside() = (%v, %v), want (%v, %v)", got0.ToString(), got1.ToString(), tt.want0, tt.want1) } }) @@ -287,7 +292,7 @@ func TestTickUpdate(t *testing.T) { { name: "does not flip from nonzero to greater nonzero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -302,7 +307,7 @@ func TestTickUpdate(t *testing.T) { { name: "flips from nonzero to zero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -317,7 +322,7 @@ func TestTickUpdate(t *testing.T) { { name: "does not flip from nonzero to lesser nonzero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -332,7 +337,7 @@ func TestTickUpdate(t *testing.T) { { name: "reverts if total liquidity gross is greater than max", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), true, u256.NewUint(3)) @@ -384,8 +389,8 @@ func TestTickUpdate(t *testing.T) { { name: "assumes all growth happens below ticks lte current tick", preconditions: func() { - pool.deleteTick(t, 0) - pool.deleteTick(t, 1) + deleteTick(t, pool, 0) + deleteTick(t, pool, 1) pool.tickUpdate(1, 1, i256.One(), u256.One(), u256.NewUint(2), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) }, tick: 0, @@ -574,8 +579,125 @@ func TestGetTick(t *testing.T) { } } -func (pool *Pool) setTick( +func TestGetFeeGrowthBelowX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + lowerTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // fee growth outside for token 1 + } + + tests := []struct { + name string + tickLower int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent >= tickLower - Return feeGrowthOutside directly", + tickLower: 100, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent > tickLower - Return feeGrowthOutside directly", + tickLower: 50, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent < tickLower - Subtract feeGrowthOutside from global", + tickLower: 100, + tickCurrent: 50, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( + tt.tickLower, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + lowerTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func TestGetFeeGrowthAboveX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + upperTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // Fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // Fee growth outside for token 1 + } + + tests := []struct { + name string + tickUpper int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent < tickUpper - Return feeGrowthOutside directly", + tickUpper: 100, + tickCurrent: 50, + expectedFeeGrowth0: upperTick.feeGrowthOutside0X128, // 300 + expectedFeeGrowth1: upperTick.feeGrowthOutside1X128, // 500 + }, + { + name: "tickCurrent >= tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 150, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + { + name: "tickCurrent == tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 100, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( + tt.tickUpper, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + upperTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func setTick( t *testing.T, + pool *Pool, tick int32, feeGrowthOutside0X128 *u256.Uint, feeGrowthOutside1X128 *u256.Uint, @@ -589,7 +711,7 @@ func (pool *Pool) setTick( t.Helper() info := pool.ticks[tick] - info.init() + info.valueOrZero() info.feeGrowthOutside0X128 = feeGrowthOutside0X128 info.feeGrowthOutside1X128 = feeGrowthOutside1X128 @@ -603,8 +725,7 @@ func (pool *Pool) setTick( pool.ticks[tick] = info } -func (pool *Pool) deleteTick(t *testing.T, tick int32) { +func deleteTick(t *testing.T, pool *Pool, tick int32) { t.Helper() - delete(pool.ticks, tick) } diff --git a/pool/type.gno b/pool/type.gno index 0135c34ef..2cca5dbc9 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -227,7 +227,34 @@ type PositionInfo struct { tokensOwed1 *u256.Uint } -func (p *PositionInfo) init() { +// valueOrZero initializes nil fields in PositionInfo to zero. +// +// This function ensures that all numeric fields in the PositionInfo struct are not nil. +// If a field is nil, it is replaced with a zero value, maintaining consistency and preventing +// potential null pointer issues during calculations. +// +// Fields affected: +// - liquidity: The liquidity amount associated with the position. +// - feeGrowthInside0LastX128: Fee growth for token 0 inside the tick range, last recorded value. +// - feeGrowthInside1LastX128: Fee growth for token 1 inside the tick range, last recorded value. +// - tokensOwed0: The amount of token 0 owed to the position owner. +// - tokensOwed1: The amount of token 1 owed to the position owner. +// +// Behavior: +// - If a field is nil, it is set to its equivalent zero value. +// - If a field already has a value, it remains unchanged. +// +// Example: +// +// position := &PositionInfo{} +// position.valueOrZero() +// fmt.Println(position.liquidity) // Output: 0 +// +// Notes: +// - This function is useful for ensuring numeric fields are properly initialized +// before performing operations or calculations. +// - Prevents runtime errors caused by nil values. +func (p *PositionInfo) valueOrZero() { p.liquidity = p.liquidity.NilToZero() p.feeGrowthInside0LastX128 = p.feeGrowthInside0LastX128.NilToZero() p.feeGrowthInside1LastX128 = p.feeGrowthInside1LastX128.NilToZero() @@ -259,7 +286,26 @@ type TickInfo struct { initialized bool // whether the tick is initialized } -func (t *TickInfo) init() { +// valueOrZero ensures that all fields of TickInfo are valid by setting nil fields to zero, +// while retaining existing values if they are not nil. +// This function updates the TickInfo struct to replace any nil values in its fields +// with their respective zero values, ensuring data consistency. +// +// Behavior: +// - If a field is nil, it is replaced with its zero value. +// - If a field already has a valid value, the value remains unchanged. +// +// Fields: +// - liquidityGross: Gross liquidity for the tick, set to zero if nil, otherwise retains its value. +// - liquidityNet: Net liquidity for the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside0X128: Accumulated fee growth for token0 outside the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside1X128: Accumulated fee growth for token1 outside the tick, set to zero if nil, otherwise retains its value. +// - secondsPerLiquidityOutsideX128: Time per liquidity outside the tick, set to zero if nil, otherwise retains its value. +// +// Use Case: +// This function ensures all numeric fields in TickInfo are non-nil and have valid values, +// preventing potential runtime errors caused by nil values during operations like arithmetic or comparisons. +func (t *TickInfo) valueOrZero() { t.liquidityGross = t.liquidityGross.NilToZero() t.liquidityNet = t.liquidityNet.NilToZero() t.feeGrowthOutside0X128 = t.feeGrowthOutside0X128.NilToZero() @@ -327,71 +373,71 @@ func newPool(poolInfo *createPoolParams) *Pool { } } -func (p *Pool) GetToken0Path() string { +func (p *Pool) Token0Path() string { return p.token0Path } -func (p *Pool) GetToken1Path() string { +func (p *Pool) Token1Path() string { return p.token1Path } -func (p *Pool) GetFee() uint32 { +func (p *Pool) Fee() uint32 { return p.fee } -func (p *Pool) GetBalanceToken0() *u256.Uint { +func (p *Pool) BalanceToken0() *u256.Uint { return p.balances.token0 } -func (p *Pool) GetBalanceToken1() *u256.Uint { +func (p *Pool) BalanceToken1() *u256.Uint { return p.balances.token1 } -func (p *Pool) GetTickSpacing() int32 { +func (p *Pool) TickSpacing() int32 { return p.tickSpacing } -func (p *Pool) GetMaxLiquidityPerTick() *u256.Uint { +func (p *Pool) MaxLiquidityPerTick() *u256.Uint { return p.maxLiquidityPerTick } -func (p *Pool) GetSlot0() Slot0 { +func (p *Pool) Slot0() Slot0 { return p.slot0 } -func (p *Pool) GetSlot0SqrtPriceX96() *u256.Uint { +func (p *Pool) Slot0SqrtPriceX96() *u256.Uint { return p.slot0.sqrtPriceX96 } -func (p *Pool) GetSlot0Tick() int32 { +func (p *Pool) Slot0Tick() int32 { return p.slot0.tick } -func (p *Pool) GetSlot0FeeProtocol() uint8 { +func (p *Pool) Slot0FeeProtocol() uint8 { return p.slot0.feeProtocol } -func (p *Pool) GetSlot0Unlocked() bool { +func (p *Pool) Slot0Unlocked() bool { return p.slot0.unlocked } -func (p *Pool) GetFeeGrowthGlobal0X128() *u256.Uint { +func (p *Pool) FeeGrowthGlobal0X128() *u256.Uint { return p.feeGrowthGlobal0X128 } -func (p *Pool) GetFeeGrowthGlobal1X128() *u256.Uint { +func (p *Pool) FeeGrowthGlobal1X128() *u256.Uint { return p.feeGrowthGlobal1X128 } -func (p *Pool) GetProtocolFeesToken0() *u256.Uint { +func (p *Pool) ProtocolFeesToken0() *u256.Uint { return p.protocolFees.token0 } -func (p *Pool) GetProtocolFeesToken1() *u256.Uint { +func (p *Pool) ProtocolFeesToken1() *u256.Uint { return p.protocolFees.token1 } -func (p *Pool) GetLiquidity() *u256.Uint { +func (p *Pool) Liquidity() *u256.Uint { return p.liquidity } diff --git a/pool/utils.gno b/pool/utils.gno index e45116217..c731d3618 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -4,24 +4,129 @@ import ( "std" "gno.land/p/demo/ufmt" - pusers "gno.land/p/demo/users" - + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) -func safeConvertToUint64(value *u256.Uint) (uint64, error) { +// safeConvertToUint64 safely converts a *u256.Uint value to a uint64, ensuring no overflow. +// +// This function attempts to convert the given *u256.Uint value to a uint64. If the value exceeds +// the maximum allowable range for uint64 (`2^64 - 1`), it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - uint64: The converted value if it falls within the uint64 range. +// +// Panics: +// - If the `value` exceeds the range of uint64, the function will panic with an error indicating +// the overflow and the original value. +// +// Notes: +// - This function uses the `Uint64WithOverflow` method to detect overflow during the conversion. +// - It is essential to validate large values before calling this function to avoid unexpected panics. +// +// Example: +// safeValue := safeConvertToUint64(u256.MustFromDecimal("18446744073709551615")) // Valid conversion +// safeConvertToUint64(u256.MustFromDecimal("18446744073709551616")) // Panics due to overflow +func safeConvertToUint64(value *u256.Uint) uint64 { res, overflow := value.Uint64WithOverflow() if overflow { - return 0, ufmt.Errorf( + panic(ufmt.Sprintf( "%v: amount(%s) overflows uint64 range", - errOutOfRange, value.ToString(), - ) + errOutOfRange, value.ToString())) } + return res +} + +// safeConvertToInt128 safely converts a *u256.Uint value to an *i256.Int, ensuring it does not exceed the int128 range. +// +// This function converts an unsigned 256-bit integer (*u256.Uint) into a signed 256-bit integer (*i256.Int). +// It checks whether the resulting value falls within the valid range of int128 (`-2^127` to `2^127 - 1`). +// If the value exceeds the maximum allowable int128 range, it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - *i256.Int: The converted value if it falls within the int128 range. +// +// Panics: +// - If the converted value exceeds the maximum int128 value (`2^127 - 1`), the function will panic with an +// error message indicating the overflow and the original value. +// +// Notes: +// - The function uses `i256.FromUint256` to perform the conversion. +// - The constant `MAX_INT128` is used to define the upper bound of the int128 range (`170141183460469231731687303715884105727`). +// +// Example: +// validInt128 := safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105727")) // Valid conversion +// safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105728")) // Panics due to overflow +func safeConvertToInt128(value *u256.Uint) *i256.Int { + liquidityDelta := i256.FromUint256(value) + if liquidityDelta.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } + return liquidityDelta +} - return res, nil +// toUint128 ensures a *u256.Uint value fits within the uint128 range. +// +// This function validates that the given `value` is properly initialized and checks whether +// it exceeds the maximum value of uint128. If the value exceeds the uint128 range, +// it applies a masking operation to truncate the value to fit within the uint128 limit. +// +// Parameters: +// - value: *u256.Uint, the value to be checked and possibly truncated. +// +// Returns: +// - *u256.Uint: A value guaranteed to fit within the uint128 range. +// +// Notes: +// - The mask ensures that only the lower 128 bits of the value are retained. +// - If the input value is already within the uint128 range, it remains unchanged. +// - MAX_UINT128 is a constant representing `2^128 - 1`. +func toUint128(value *u256.Uint) *u256.Uint { + assertOnlyInitializedUint256(value) + if value.Gt(u256.MustFromDecimal(consts.MAX_UINT128)) { + mask := new(u256.Uint).Lsh(u256.One(), consts.Q128_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + value = value.And(value, mask) + } + return value } +// a2u converts a std.Address to a pusers.AddressOrName, ensuring the input address is valid. +// +// This function takes a `std.Address` and verifies its validity. If the address is invalid, +// the function triggers a panic with an appropriate error message. For valid addresses, +// it performs the conversion to `pusers.AddressOrName`. +// +// Parameters: +// - addr (std.Address): The input address to be converted. +// +// Returns: +// - pusers.AddressOrName: The converted address, wrapped as a `pusers.AddressOrName` type. +// +// Panics: +// - If the provided `addr` is invalid, the function will panic with an error indicating +// the invalid address. +// +// Notes: +// - The function relies on the `addr.IsValid()` method to determine the validity of the input address. +// - It uses `addDetailToError` to provide additional context for the error message when an invalid +// address is encountered. +// +// Example: +// converted := a2u(std.Address("validAddress")) // Successful conversion +// a2u(std.Address("")) // Panics due to invalid address func a2u(addr std.Address) pusers.AddressOrName { if !addr.IsValid() { panic(addDetailToError( @@ -32,23 +137,60 @@ func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } +// u256Min returns the smaller of two *u256.Uint values. +// +// This function compares two unsigned 256-bit integers and returns the smaller of the two. +// If `num1` is less than `num2`, it returns `num1`; otherwise, it returns `num2`. +// +// Parameters: +// - num1 (*u256.Uint): The first unsigned 256-bit integer. +// - num2 (*u256.Uint): The second unsigned 256-bit integer. +// +// Returns: +// - *u256.Uint: The smaller of `num1` and `num2`. +// +// Notes: +// - This function uses the `Lt` (less than) method of `*u256.Uint` to perform the comparison. +// - The function assumes both input values are non-nil. If nil inputs are possible in the usage context, +// additional validation may be needed. +// +// Example: +// smaller := u256Min(u256.MustFromDecimal("10"), u256.MustFromDecimal("20")) // Returns 10 +// smaller := u256Min(u256.MustFromDecimal("30"), u256.MustFromDecimal("20")) // Returns 20 func u256Min(num1, num2 *u256.Uint) *u256.Uint { if num1.Lt(num2) { return num1 } - return num2 } -func isUserCall() bool { - return std.PrevRealm().IsUser() +// derivePkgAddr derives the Realm address from it's pkgPath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) } -func getPrev() (string, string) { - prev := std.PrevRealm() +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() return prev.Addr().String(), prev.PkgPath() } +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkTransferError checks transfer error. func checkTransferError(err error) { if err != nil { panic(addDetailToError( @@ -57,3 +199,79 @@ func checkTransferError(err error) { )) } } + +// checkOverFlowInt128 checks if the value overflows the int128 range. +func checkOverFlowInt128(value *i256.Int) { + if value.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } +} + +// checkTickSpacing checks if the tick is divisible by the tickSpacing. +func checkTickSpacing(tick, tickSpacing int32) { + if tick%tickSpacing != 0 { + panic(addDetailToError( + errInvalidTickAndTickSpacing, + ufmt.Sprintf("tick(%d) MOD tickSpacing(%d) != 0(%d)", tick, tickSpacing, tick%tickSpacing), + )) + } +} + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyPositionContract panics if the caller is not the position contract. +func assertOnlyPositionContract() { + caller := getPrevAddr() + if err := common.PositionOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only position(%s) can call, called from %s", consts.POSITION_ADDR, caller.String()), + )) + } +} + +// assertOnlyInitializedUint256 panics if the value is nil. +func assertOnlyInitializedUint256(value *u256.Uint) { + if value == nil { + panic(addDetailToError( + errInvalidInput, + "value is nil", + )) + } +} + +// assertOnlyAdmin panics if the caller is not the admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyGovernance panics if the caller is not the governance. +func assertOnlyGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyRegistered panics if the token is not registered. +func assertOnlyRegistered(tokenPath string) { + common.MustRegistered(tokenPath) +} diff --git a/pool/utils_test.gno b/pool/utils_test.gno index 623cc4aef..e54578948 100644 --- a/pool/utils_test.gno +++ b/pool/utils_test.gno @@ -6,10 +6,12 @@ import ( "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" - + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" "gno.land/r/demo/users" + "gno.land/r/gnoswap/v1/consts" ) func TestA2U(t *testing.T) { @@ -125,7 +127,7 @@ func TestIsUserCall(t *testing.T) { } } -func TestGetPrev(t *testing.T) { +func TestGetPrevAsString(t *testing.T) { tests := []struct { name string action func() (string, string) @@ -137,7 +139,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { userRealm := std.NewUserRealm(std.Address("user")) std.TestSetRealm(userRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: "user", expectedPkgPath: "", @@ -147,7 +149,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { codeRealm := std.NewCodeRealm("gno.land/r/demo/realm") std.TestSetRealm(codeRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: std.DerivePkgAddr("gno.land/r/demo/realm").String(), expectedPkgPath: "gno.land/r/demo/realm", @@ -162,3 +164,290 @@ func TestGetPrev(t *testing.T) { }) } } + +func TestSafeConvertToUint64(t *testing.T) { + tests := []struct { + name string + value *u256.Uint + wantRes uint64 + wantPanic bool + }{ + {"normal conversion", u256.NewUint(123), 123, false}, + {"overflow", u256.MustFromDecimal(consts.MAX_UINT128), 0, true}, + {"max uint64", u256.NewUint(1<<64 - 1), 1<<64 - 1, false}, + {"zero", u256.NewUint(0), 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToUint64(tt.value) + if res != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestSafeConvertToInt128(t *testing.T) { + tests := []struct { + name string + value string + wantRes string + wantPanic bool + }{ + {"normal conversion", "170141183460469231731687303715884105727", "170141183460469231731687303715884105727", false}, + {"overflow", "170141183460469231731687303715884105728", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToInt128(u256.MustFromDecimal(tt.value)) + if res.ToString() != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestA2u(t *testing.T) { + var ( + addr = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + input std.Address + expected pusers.AddressOrName + }{ + { + name: "Success - a2u", + input: addr, + expected: pusers.AddressOrName(addr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := a2u(tc.input) + uassert.Equal(t, users.Resolve(got).String(), users.Resolve(tc.expected).String()) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestCheckOverFlowInt128(t *testing.T) { + tests := []struct { + name string + input *i256.Int + shouldPanic bool + expected string + }{ + { + name: "Valid value within int128 range", + input: i256.MustFromDecimal("1"), + shouldPanic: false, + }, + { + name: "Edge case - MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT128), + shouldPanic: false, + }, + { + name: "Overflow case - exceeds MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT256), // 최대값 + 1 + shouldPanic: true, + expected: "[GNOSWAP-POOL-026] overflow: amount(170141183460469231731687303715884105728) overflows int128 range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + t.Errorf("Expected panic but none occurred") + } + }() + checkOverFlowInt128(tt.input) + }) + } +} + +func TestCheckTickSpacing(t *testing.T) { + tests := []struct { + name string + tick int32 + tickSpacing int32 + shouldPanic bool + expected string + }{ + { + name: "Valid tick - divisible by tickSpacing", + tick: 120, + tickSpacing: 60, + shouldPanic: false, + }, + { + name: "Valid tick - zero tick", + tick: 0, + tickSpacing: 10, + shouldPanic: false, + }, + { + name: "Invalid tick - not divisible", + tick: 15, + tickSpacing: 10, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(15) MOD tickSpacing(10) != 0(5)", + }, + { + name: "Invalid tick - negative tick", + tick: -35, + tickSpacing: 20, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(-35) MOD tickSpacing(20) != 0(-15)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + } + }() + checkTickSpacing(tt.tick, tt.tickSpacing) + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POOL-023] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} diff --git a/position/_RPC_api.gno b/position/_RPC_api.gno index 10df9acf2..5b1cd3d0d 100644 --- a/position/_RPC_api.gno +++ b/position/_RPC_api.gno @@ -7,14 +7,11 @@ import ( "gno.land/p/demo/json" "gno.land/p/demo/ufmt" + i256 "gno.land/p/gnoswap/int256" "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" - - pl "gno.land/r/gnoswap/v1/pool" - - i256 "gno.land/p/gnoswap/int256" - "gno.land/r/gnoswap/v1/gnft" + pl "gno.land/r/gnoswap/v1/pool" ) type RpcPosition struct { @@ -69,7 +66,7 @@ func ApiGetPositions() string { } // STAT NODE - _stat := json.ObjectNode("", map[string]*json.Node{ + stat := json.ObjectNode("", map[string]*json.Node{ "height": json.NumberNode("height", float64(std.GetHeight())), "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), }) @@ -99,7 +96,7 @@ func ApiGetPositions() string { } node := json.ObjectNode("", map[string]*json.Node{ - "stat": _stat, + "stat": stat, "response": responses, }) @@ -390,7 +387,7 @@ func rpcMakePosition(lpTokenId uint64) RpcPosition { burned := isBurned(lpTokenId) pool := pl.GetPoolFromPoolPath(position.poolKey) - currentX96 := pool.GetSlot0SqrtPriceX96() + currentX96 := pool.Slot0SqrtPriceX96() lowerX96 := common.TickMathGetSqrtRatioAtTick(position.tickLower) upperX96 := common.TickMathGetSqrtRatioAtTick(position.tickUpper) @@ -439,12 +436,12 @@ func unclaimedFee(tokenId uint64) (*i256.Int, *i256.Int) { poolKey := positions[tokenId].poolKey pool := pl.GetPoolFromPoolPath(poolKey) - currentTick := pool.GetSlot0Tick() + currentTick := pool.Slot0Tick() - _feeGrowthGlobal0X128 := pool.GetFeeGrowthGlobal0X128() // u256 + _feeGrowthGlobal0X128 := pool.FeeGrowthGlobal0X128() // u256 feeGrowthGlobal0X128 := i256.FromUint256(_feeGrowthGlobal0X128) // i256 - _feeGrowthGlobal1X128 := pool.GetFeeGrowthGlobal1X128() // u256 + _feeGrowthGlobal1X128 := pool.FeeGrowthGlobal1X128() // u256 feeGrowthGlobal1X128 := i256.FromUint256(_feeGrowthGlobal1X128) // i256 _tickUpperFeeGrowthOutside0X128 := pool.GetTickFeeGrowthOutside0X128(tickUpper) // u256 diff --git a/position/liquidity_management.gno b/position/liquidity_management.gno index be926f912..d700bfb33 100644 --- a/position/liquidity_management.gno +++ b/position/liquidity_management.gno @@ -16,7 +16,7 @@ import ( func addLiquidity(params AddLiquidityParams) (*u256.Uint, *u256.Uint, *u256.Uint) { pool := pl.GetPoolFromPoolPath(params.poolKey) - sqrtPriceX96 := pool.GetSlot0SqrtPriceX96() + sqrtPriceX96 := pool.Slot0SqrtPriceX96() sqrtRatioAX96 := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioBX96 := common.TickMathGetSqrtRatioAtTick(params.tickUpper) diff --git a/position/position.gno b/position/position.gno index f33ffd0b6..42503da34 100644 --- a/position/position.gno +++ b/position/position.gno @@ -189,7 +189,8 @@ func mint(params MintParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint) { nextId++ positionKey := positionKeyCompute(GetOrigPkgAddr(), params.tickLower, params.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -320,7 +321,8 @@ func increaseLiquidity(params IncreaseLiquidityParams) (uint64, *u256.Uint, *u25 pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -446,7 +448,8 @@ func decreaseLiquidity(params DecreaseLiquidityParams) (uint64, *u256.Uint, *u25 verifyBurnedAmounts(burnedAmount0, burnedAmount1, params.amount0Min, params.amount1Min) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -592,7 +595,8 @@ func Reposition( pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), tickLower, tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -676,7 +680,8 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) pool := pl.GetPoolFromPoolPath(position.poolKey) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) From 1e9c27049826c7283009baff46e6714de446d5d4 Mon Sep 17 00:00:00 2001 From: Blake <104744707+r3v4s@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:13:11 +0900 Subject: [PATCH 3/4] refactor: use avl.Tree in pool contract (#430) --- pool/_helper_test.gno | 13 +++-- pool/api.gno | 43 +++++++++----- pool/api_test.gno | 2 +- pool/getter.gno | 6 +- pool/getter_test.gno | 16 +++--- pool/pool.gno | 17 +++--- pool/pool_manager.gno | 82 ++++++++++++++++----------- pool/pool_manager_test.gno | 31 +--------- pool/pool_test.gno | 7 ++- pool/position.gno | 38 ++++++++----- pool/position_update.gno | 4 +- pool/position_update_test.gno | 4 +- pool/protocol_fee_withdrawal_test.gno | 3 +- pool/swap_test.gno | 17 +++--- pool/tick.gno | 30 +++++++--- pool/tick_bitmap.gno | 29 ++++++++-- pool/tick_bitmap_test.gno | 7 ++- pool/tick_test.gno | 41 +++++++------- pool/type.gno | 29 +++++----- pool/utils.gno | 2 +- 20 files changed, 236 insertions(+), 185 deletions(-) diff --git a/pool/_helper_test.gno b/pool/_helper_test.gno index ea719ff2c..34826edaa 100644 --- a/pool/_helper_test.gno +++ b/pool/_helper_test.gno @@ -13,6 +13,7 @@ import ( "gno.land/r/onbloc/qux" "gno.land/r/onbloc/usdc" + "gno.land/p/demo/avl" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" pusers "gno.land/p/demo/users" @@ -23,7 +24,7 @@ import ( const ( ugnotDenom string = "ugnot" - ugnotPath string = "gno.land/r/gnoswap/v1/pool:ugnot" + ugnotPath string = "ugnot" wugnotPath string = "gno.land/r/demo/wugnot" gnsPath string = "gno.land/r/gnoswap/v1/gns" barPath string = "gno.land/r/onbloc/bar" @@ -517,7 +518,7 @@ func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { // resetObject resets the object state(clear or make it default values) func resetObject(t *testing.T) { - pools = make(poolMap) + pools = avl.NewTree() slot0FeeProtocol = 0 poolCreationFee = 100_000_000 withdrawalFeeBPS = 100 @@ -570,11 +571,11 @@ func burnUsdc(addr pusers.AddressOrName) { func TestBeforeResetObject(t *testing.T) { // make some data - pools = make(poolMap) - pools["gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc"] = &Pool{ + pools = avl.NewTree() + pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{ token0Path: "gno.land/r/gnoswap/v1/gns", token1Path: "gno.land/r/onbloc/usdc", - } + }) slot0FeeProtocol = 1 poolCreationFee = 100_000_000 @@ -591,7 +592,7 @@ func TestBeforeResetObject(t *testing.T) { func TestResetObject(t *testing.T) { resetObject(t) - uassert.Equal(t, len(pools), 0) + uassert.Equal(t, pools.Size(), 0) uassert.Equal(t, slot0FeeProtocol, uint8(0)) uassert.Equal(t, poolCreationFee, uint64(100_000_000)) uassert.Equal(t, withdrawalFeeBPS, uint64(100)) diff --git a/pool/api.gno b/pool/api.gno index afe4d9fa6..32c15a7db 100644 --- a/pool/api.gno +++ b/pool/api.gno @@ -4,6 +4,7 @@ import ( b64 "encoding/base64" "gno.land/p/demo/json" + u256 "gno.land/p/gnoswap/uint256" "std" "strconv" @@ -81,14 +82,16 @@ type RpcPosition struct { func ApiGetPools() string { rpcPools := []RpcPool{} - for poolPath, _ := range pools { + pools.Iterate("", "", func(poolPath string, value interface{}) bool { rpcPool := rpcMakePool(poolPath) rpcPools = append(rpcPools, rpcPool) - } + + return false + }) responses := json.ArrayNode("", []*json.Node{}) for _, pool := range rpcPools { - _poolNode := json.ObjectNode("", map[string]*json.Node{ + poolNode := json.ObjectNode("", map[string]*json.Node{ "poolPath": json.StringNode("poolPath", pool.PoolPath), "token0Path": json.StringNode("token0Path", pool.Token0Path), "token1Path": json.StringNode("token1Path", pool.Token1Path), @@ -110,7 +113,7 @@ func ApiGetPools() string { "tickBitmaps": json.ObjectNode("tickBitmaps", makeRpcTickBitmapsJson(pool.TickBitmaps)), "positions": json.ArrayNode("positions", makeRpcPositionsArray(pool.Positions)), }) - responses.AppendArray(_poolNode) + responses.AppendArray(poolNode) } node := json.ObjectNode("", map[string]*json.Node{ @@ -122,8 +125,7 @@ func ApiGetPools() string { } func ApiGetPool(poolPath string) string { - _, exist := pools[poolPath] - if !exist { + if !pools.Has(poolPath) { return "" } rpcPool := rpcMakePool(poolPath) @@ -198,8 +200,11 @@ func rpcMakePool(poolPath string) RpcPool { rpcPool.Liquidity = pool.liquidity.ToString() rpcPool.Ticks = RpcTicks{} - for tick, tickInfo := range pool.ticks { - rpcPool.Ticks[tick] = RpcTickInfo{ + pool.ticks.Iterate("", "", func(tickStr string, iTickInfo interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + tickInfo := iTickInfo.(TickInfo) + + rpcPool.Ticks[int32(tick)] = RpcTickInfo{ LiquidityGross: tickInfo.liquidityGross.ToString(), LiquidityNet: tickInfo.liquidityNet.ToString(), FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(), @@ -209,17 +214,22 @@ func rpcMakePool(poolPath string) RpcPool { SecondsOutside: tickInfo.secondsOutside, Initialized: tickInfo.initialized, } - } + + return false + }) rpcPool.TickBitmaps = RpcTickBitmaps{} - for tick, tickBitmap := range pool.tickBitmaps { - rpcPool.TickBitmaps[tick] = tickBitmap.ToString() - } + pool.tickBitmaps.Iterate("", "", func(tickStr string, iTickBitmap interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + pool.setTickBitmap(int16(tick), iTickBitmap.(*u256.Uint)) + + return false + }) - Positions := pool.positions rpcPositions := []RpcPosition{} - for posKey, posInfo := range Positions { + pool.positions.Iterate("", "", func(posKey string, iPositionInfo interface{}) bool { owner, tickLower, tickUpper := posKeyDivide(posKey) + posInfo := iPositionInfo.(PositionInfo) rpcPositions = append(rpcPositions, RpcPosition{ Owner: owner, @@ -229,7 +239,10 @@ func rpcMakePool(poolPath string) RpcPool { Token0Owed: posInfo.tokensOwed0.ToString(), Token1Owed: posInfo.tokensOwed1.ToString(), }) - } + + return false + }) + rpcPool.Positions = rpcPositions return rpcPool diff --git a/pool/api_test.gno b/pool/api_test.gno index 58031e921..fcef1abe2 100644 --- a/pool/api_test.gno +++ b/pool/api_test.gno @@ -21,7 +21,7 @@ func TestInitTwoPools(t *testing.T) { // bar:baz CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000 - uassert.Equal(t, len(pools), 2) + uassert.Equal(t, pools.Size(), 2) } func TestApiGetPools(t *testing.T) { diff --git a/pool/getter.gno b/pool/getter.gno index 443516141..6c22f1d15 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -2,9 +2,11 @@ package pool func PoolGetPoolList() []string { poolPaths := []string{} - for poolPath, _ := range pools { + pools.Iterate("", "", func(poolPath string, _ interface{}) bool { poolPaths = append(poolPaths, poolPath) - } + + return false + }) return poolPaths } diff --git a/pool/getter_test.gno b/pool/getter_test.gno index 1772cc228..e0727669b 100644 --- a/pool/getter_test.gno +++ b/pool/getter_test.gno @@ -3,6 +3,8 @@ package pool import ( "testing" + "gno.land/p/demo/avl" + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" ) @@ -35,8 +37,8 @@ func TestInitData(t *testing.T) { liquidity: u256.NewUint(1000000), } - mockTicks := Ticks{} - mockTicks[0] = TickInfo{ + mockTicks := avl.NewTree() + mockTicks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000000), liquidityNet: i256.NewInt(2000000), feeGrowthOutside0X128: u256.NewUint(3000000), @@ -45,20 +47,20 @@ func TestInitData(t *testing.T) { secondsPerLiquidityOutsideX128: u256.NewUint(6000000), secondsOutside: 7, initialized: true, - } + }) mockPool.ticks = mockTicks - mockPositions := Positions{} - mockPositions["test_position"] = PositionInfo{ + mockPositions := avl.NewTree() + mockPositions.Set("test_position", PositionInfo{ liquidity: u256.NewUint(1000000), feeGrowthInside0LastX128: u256.NewUint(2000000), feeGrowthInside1LastX128: u256.NewUint(3000000), tokensOwed0: u256.NewUint(4000000), tokensOwed1: u256.NewUint(5000000), - } + }) mockPool.positions = mockPositions - pools["token0:token1:3000"] = mockPool + pools.Set("token0:token1:3000", mockPool) } func TestPoolGetters(t *testing.T) { diff --git a/pool/pool.gno b/pool/pool.gno index df2e51ee8..67308c228 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -124,7 +124,7 @@ func Burn( } positionKey := getPositionKey(caller, tickLower, tickUpper) - pool.positions[positionKey] = position + pool.setPosition(positionKey, position) // actual token transfer happens in Collect() return amount0.ToString(), amount1.ToString() @@ -192,7 +192,7 @@ func Collect( checkTransferError(token1.Transfer(recipient, amount1.Uint64())) } - pool.positions[positionKey] = position + pool.setPosition(positionKey, position) return amount0.ToString(), amount1.ToString() } @@ -316,12 +316,15 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { // - feePrtocol0 occupies the lower 4 bits // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) + // Update slot0 for each pool - for _, pool := range pools { - if pool != nil { - pool.slot0.feeProtocol = newFee - } - } + pools.Iterate("", "", func(poolPath string, iPool interface{}) bool { + pool := iPool.(*Pool) + pool.slot0.feeProtocol = newFee + + return false + }) + // update slot0 slot0FeeProtocol = newFee return newFee diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index 575b95594..cd04d0416 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -2,8 +2,10 @@ package pool import ( "std" + "strconv" "strings" + "gno.land/p/demo/avl" "gno.land/p/demo/ufmt" "gno.land/r/gnoswap/v1/consts" @@ -15,39 +17,18 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -type poolMap map[string]*Pool - -func (pm *poolMap) Get(poolPath string) (*Pool, bool) { - pool, exist := (*pm)[poolPath] - if !exist { - return nil, false - } - - return pool, true -} - -func (pm *poolMap) Set(poolPath string, pool *Pool) { - (*pm)[poolPath] = pool -} - -type tickSpacingMap map[uint32]int32 - -func (t *tickSpacingMap) Get(fee uint32) int32 { - return (*t)[fee] -} - var ( - feeAmountTickSpacing tickSpacingMap = make(tickSpacingMap) // maps fee to tickSpacing || map[feeAmount]tick_spacing - pools poolMap = make(poolMap) // maps poolPath to pool || map[poolPath]*Pool + feeAmountTickSpacing = avl.NewTree() // feeBps(uint32) -> tickSpacing(int32) + pools = avl.NewTree() // poolPath -> *Pool slot0FeeProtocol uint8 = 0 ) func init() { - feeAmountTickSpacing[100] = 1 // 0.01% - feeAmountTickSpacing[500] = 10 // 0.05% - feeAmountTickSpacing[3000] = 60 // 0.3% - feeAmountTickSpacing[10000] = 200 // 1% + setFeeAmountTickSpacing(100, 1) // 0.01% + setFeeAmountTickSpacing(500, 10) // 0.05% + setFeeAmountTickSpacing(3000, 60) // 0.3% + setFeeAmountTickSpacing(10000, 200) // 1% } // createPoolParams holds the essential parameters for creating a new pool. @@ -66,7 +47,7 @@ func newPoolParams( sqrtPriceX96 string, ) *createPoolParams { price := u256.MustFromDecimal(sqrtPriceX96) - tickSpacing := feeAmountTickSpacing.Get(fee) + tickSpacing := GetFeeAmountTickSpacing(fee) return &createPoolParams{ token0Path: token0Path, token1Path: token1Path, @@ -211,8 +192,7 @@ func CreatePool( // DoesPoolPathExist checks if a pool exists for the given poolPath. // The poolPath is a unique identifier for a pool, combining token paths and fee. func DoesPoolPathExist(poolPath string) bool { - _, exist := pools[poolPath] - return exist + return pools.Has(poolPath) } // GetPool retrieves a pool instance based on the provided token paths and fee tier. @@ -260,14 +240,14 @@ func GetPool(token0Path, token1Path string, fee uint32) *Pool { // Example: // pool := GetPoolFromPoolPath("path/to/pool") func GetPoolFromPoolPath(poolPath string) *Pool { - pool, exist := pools[poolPath] + iPool, exist := pools.Get(poolPath) if !exist { panic(addDetailToError( errDataNotFound, ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), )) } - return pool + return iPool.(*Pool) } // GetPoolPath generates a unique pool path string based on the token paths and fee tier. @@ -303,3 +283,41 @@ func GetPoolPath(token0Path, token1Path string, fee uint32) string { } return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) } + +// GetFeeAmountTickSpacing retrieves the tick spacing associated with a given fee amount. +// The tick spacing determines the minimum distance between ticks in the pool. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// +// Returns: +// - int32: The tick spacing value for the given fee tier +// +// Panics: +// - If the fee amount is not registered in feeAmountTickSpacing +func GetFeeAmountTickSpacing(fee uint32) int32 { + feeStr := strconv.FormatUint(uint64(fee), 10) + iTickSpacing, exist := feeAmountTickSpacing.Get(feeStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected feeAmountTickSpacing(%s) to exist", feeStr), + )) + } + + return iTickSpacing.(int32) +} + +// setFeeAmountTickSpacing associates a tick spacing value with a fee amount. +// This is typically called during initialization to set up supported fee tiers. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// - tickSpacing (int32): The minimum tick spacing for this fee tier +// +// Note: Smaller tick spacing allows for more granular price points but increases +// computational overhead. Higher fee tiers typically use larger tick spacing. +func setFeeAmountTickSpacing(fee uint32, tickSpacing int32) { + feeStr := strconv.FormatUint(uint64(fee), 10) + feeAmountTickSpacing.Set(feeStr, tickSpacing) +} diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 36637a281..4ea35cd95 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -8,29 +8,6 @@ import ( "gno.land/r/gnoswap/v1/consts" ) -func TestPoolMapOperations(t *testing.T) { - pm := make(poolMap) - - poolPath := "token0:token1:500" - params := newPoolParams("token0", "token1", 500, "4295128740") - pool := newPool(params) - - pm.Set(poolPath, pool) - - retrieved, exists := pm.Get(poolPath) - if !exists { - t.Error("Expected pool to exist") - } - if retrieved != pool { - t.Error("Retrieved pool doesn't match original") - } - - _, exists = pm.Get("nonexistent") - if exists { - t.Error("Expected pool to not exist") - } -} - func TestNewPoolParams(t *testing.T) { params := newPoolParams( "token0", @@ -98,7 +75,7 @@ func TestTickSpacingMap(t *testing.T) { } for _, tt := range tests { - spacing := feeAmountTickSpacing.Get(tt.fee) + spacing := GetFeeAmountTickSpacing(tt.fee) if spacing != tt.tickSpacing { t.Errorf("For fee %d, expected tick spacing %d, got %d", tt.fee, tt.tickSpacing, spacing) @@ -176,11 +153,7 @@ func TestCreatePool(t *testing.T) { if !tt.shouldPanic { // verify pool was created correctly poolPath := GetPoolPath(tt.token0Path, tt.token1Path, tt.fee) - pool, exists := pools.Get(poolPath) - if !exists { - t.Errorf("pool was not created") - return - } + pool := mustGetPool(poolPath) // check if GNOT was properly wrapped expectedToken0 := tt.token0Path diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 7a7050321..9cef9f82e 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" u256 "gno.land/p/gnoswap/uint256" @@ -57,7 +58,7 @@ func TestBurn(t *testing.T) { tokensOwed1: u256.NewUint(0), } mockPool := &Pool{ - positions: make(map[string]PositionInfo), + positions: avl.NewTree(), } GetPool = func(token0Path, token1Path string, fee uint32) *Pool { @@ -98,7 +99,7 @@ func TestBurn(t *testing.T) { // setup position for this test posKey := getPositionKey(mockCaller, tt.tickLower, tt.tickUpper) - mockPool.positions[posKey] = mockPosition + mockPool.positions.Set(posKey, mockPosition) if tt.expectPanic { defer func() { @@ -125,7 +126,7 @@ func TestBurn(t *testing.T) { t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) } - newPosition := mockPool.positions[posKey] + newPosition := mockPool.mustGetPosition(posKey) if newPosition.tokensOwed0.IsZero() { t.Error("expected tokensOwed0 to be updated") } diff --git a/pool/position.gno b/pool/position.gno index 8afba9905..572b556dd 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -120,10 +120,10 @@ func (p *Pool) positionUpdateWithKey( feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() - positionToUpdate := p.GetPosition(positionKey) + positionToUpdate, _ := p.GetPosition(positionKey) positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) - p.positions[positionKey] = positionAfterUpdate + p.setPosition(positionKey, positionAfterUpdate) return positionAfterUpdate } @@ -153,23 +153,31 @@ func (p *Pool) PositionTokensOwed1(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed1 } -func (p *Pool) mustGetPosition(key string) PositionInfo { - position, exist := p.positions[key] +// GetPosition returns the position info for a given key. +func (p *Pool) GetPosition(key string) (PositionInfo, bool) { + iPositionInfo, exist := p.positions.Get(key) if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("position(%s) does not exist", key), - )) + newPosition := PositionInfo{} + newPosition.valueOrZero() + return newPosition, false } - return position + + return iPositionInfo.(PositionInfo), true } -func (p *Pool) GetPosition(key string) PositionInfo { - position, exist := p.positions[key] +// setPosition sets the position info for a given key. +func (p *Pool) setPosition(posKey string, positionInfo PositionInfo) { + p.positions.Set(posKey, positionInfo) +} + +// mustGetPosition returns the position info for a given key. +func (p *Pool) mustGetPosition(positionKey string) PositionInfo { + positionInfo, exist := p.GetPosition(positionKey) if !exist { - newPosition := PositionInfo{} - newPosition.valueOrZero() - return newPosition + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("positionKey(%s) does not exist", positionKey), + )) } - return position + return positionInfo } diff --git a/pool/position_update.gno b/pool/position_update.gno index 827d91f10..9e3fe0304 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -88,10 +88,10 @@ func (p *Pool) updatePosition(positionParams ModifyPositionParams, tick int32) P // clear any tick data that is no longer needed if positionParams.liquidityDelta.IsNeg() { if flippedLower { - delete(p.ticks, positionParams.tickLower) + p.deleteTick(positionParams.tickLower) } if flippedUpper { - delete(p.ticks, positionParams.tickUpper) + p.deleteTick(positionParams.tickUpper) } } diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno index ad2e67077..b54ae6d09 100644 --- a/pool/position_update_test.gno +++ b/pool/position_update_test.gno @@ -67,8 +67,8 @@ func TestUpdatePosition(t *testing.T) { } if !tt.positionParams.liquidityDelta.IsZero() { - lowerTick := p.ticks[tt.positionParams.tickLower] - upperTick := p.ticks[tt.positionParams.tickUpper] + lowerTick := p.mustGetTick(tt.positionParams.tickLower) + upperTick := p.mustGetTick(tt.positionParams.tickUpper) if !lowerTick.initialized { t.Error("lower tick not initialized") diff --git a/pool/protocol_fee_withdrawal_test.gno b/pool/protocol_fee_withdrawal_test.gno index 241459d91..7cbac5d39 100644 --- a/pool/protocol_fee_withdrawal_test.gno +++ b/pool/protocol_fee_withdrawal_test.gno @@ -55,8 +55,7 @@ func TestHandleWithdrawalFee(t *testing.T) { InitialisePoolTest(t) std.TestSetRealm(std.NewUserRealm(users.Resolve(position))) poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) - _, found := pools[poolPath] - if !found { + if !pools.Has(poolPath) { panic("pool not found") } TokenApprove(t, wugnotPath, alice, protocolFee, uint64(0)) diff --git a/pool/swap_test.gno b/pool/swap_test.gno index 3b73a5139..25cb7c4f8 100644 --- a/pool/swap_test.gno +++ b/pool/swap_test.gno @@ -4,6 +4,7 @@ import ( "std" "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/uassert" "gno.land/r/demo/users" @@ -201,14 +202,13 @@ func TestComputeSwap(t *testing.T) { }, feeGrowthGlobal0X128: u256.Zero(), feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), } wordPos, _ := tickBitmapPosition(0) - // TODO: use avl - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + mockPool.setTickBitmap(wordPos, u256.NewUint(1)) t.Run("basic swap", func(t *testing.T) { comp := SwapComputation{ @@ -280,7 +280,6 @@ func TestComputeSwap(t *testing.T) { } func TestSwap_Failures(t *testing.T) { - t.Skip() const addr = pusers.AddressOrName(consts.ROUTER_ADDR) tests := []struct { @@ -412,9 +411,9 @@ func TestDrySwap_Failures(t *testing.T) { }, feeGrowthGlobal0X128: u256.Zero(), feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), } originalGetPool := GetPool diff --git a/pool/tick.gno b/pool/tick.gno index 63e070222..7c147d1eb 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -1,6 +1,8 @@ package pool import ( + "strconv" + "gno.land/p/demo/ufmt" i256 "gno.land/p/gnoswap/int256" @@ -282,14 +284,21 @@ func (p *Pool) tickCross( thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) - p.ticks[tick] = thisTick + p.setTick(tick, thisTick) return thisTick.liquidityNet.Clone() } // setTick updates the tick data for the specified tick index in the pool. func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { - p.ticks[tick] = newTickInfo + tickStr := strconv.Itoa(int(tick)) + p.ticks.Set(tickStr, newTickInfo) +} + +// deleteTick deletes the tick data for the specified tick index in the pool. +func (p *Pool) deleteTick(tick int32) { + tickStr := strconv.Itoa(int(tick)) + p.ticks.Remove(tickStr) } // getTick retrieves the TickInfo associated with the specified tick index from the pool. @@ -310,9 +319,15 @@ func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { // This function ensures the retrieved tick data is always valid and safe for further operations, // such as calculations or updates, by sanitizing nil fields in the TickInfo structure. func (p *Pool) getTick(tick int32) TickInfo { - tickInfo := p.ticks[tick] - tickInfo.valueOrZero() - return tickInfo + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) + if !exist { + tickInfo := TickInfo{} + tickInfo.valueOrZero() + return tickInfo + } + + return iTickInfo.(TickInfo) } // GetTickLiquidityGross returns the gross liquidity for the specified tick. @@ -379,7 +394,8 @@ func (p *Pool) GetTickInitialized(tick int32) bool { // tickInfo := pool.mustGetTick(10) // fmt.Println("Tick Info:", tickInfo) func (p *Pool) mustGetTick(tick int32) TickInfo { - tickInfo, exist := p.ticks[tick] + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) if !exist { panic(addDetailToError( errDataNotFound, @@ -387,5 +403,5 @@ func (p *Pool) mustGetTick(tick int32) TickInfo { )) } - return tickInfo + return iTickInfo.(TickInfo) } diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 53c9cc416..dc7369b7f 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -1,6 +1,9 @@ package pool import ( + "strconv" + + "gno.land/p/demo/ufmt" plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" @@ -79,16 +82,32 @@ func (p *Pool) tickBitmapNextInitializedTickWithInOneWord( // getTickBitmap gets the tick bitmap for the given word position // if the tick bitmap is not initialized, initialize it to zero func (p *Pool) getTickBitmap(wordPos int16) *u256.Uint { - if p.tickBitmaps[wordPos] == nil { - p.tickBitmaps[wordPos] = u256.Zero() + wordPosStr := strconv.Itoa(int(wordPos)) + + if !p.tickBitmaps.Has(wordPosStr) { + p.initTickBitmap(wordPos) + } + + iU256, exist := p.tickBitmaps.Get(wordPosStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("tickBitmap(%d) does not exist", wordPos), + )) } - return p.tickBitmaps[wordPos] + return iU256.(*u256.Uint) } // setTickBitmap sets the tick bitmap for the given word position -func (p *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { - p.tickBitmaps[wordPos] = bitmap +func (p *Pool) setTickBitmap(wordPos int16, tickBitmap *u256.Uint) { + wordPosStr := strconv.Itoa(int(wordPos)) + p.tickBitmaps.Set(wordPosStr, tickBitmap) +} + +// initTickBitmap initializes the tick bitmap for the given word position +func (p *Pool) initTickBitmap(wordPos int16) { + p.setTickBitmap(wordPos, u256.Zero()) } // getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction diff --git a/pool/tick_bitmap_test.gno b/pool/tick_bitmap_test.gno index 37c8c047b..05d52a001 100644 --- a/pool/tick_bitmap_test.gno +++ b/pool/tick_bitmap_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" u256 "gno.land/p/gnoswap/uint256" ) @@ -76,7 +77,7 @@ func TestTickBitmapFlipTick(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.shouldPanic { @@ -92,7 +93,7 @@ func TestTickBitmapFlipTick(t *testing.T) { if !tt.shouldPanic { wordPos, bitPos := tickBitmapPosition(tt.tick / tt.tickSpacing) expected := new(u256.Uint).Lsh(u256.NewUint(1), uint(bitPos)) - if pool.tickBitmaps[wordPos].Cmp(expected) != 0 { + if pool.getTickBitmap(wordPos).Cmp(expected) != 0 { t.Errorf("bitmap not set correctly") } } @@ -137,7 +138,7 @@ func TestTickBitmapNextInitializedTickWithInOneWord(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.setupBitmap != nil { tt.setupBitmap(pool) diff --git a/pool/tick_test.gno b/pool/tick_test.gno index 32fef15e7..d7a2aee05 100644 --- a/pool/tick_test.gno +++ b/pool/tick_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/uassert" i256 "gno.land/p/gnoswap/int256" @@ -60,24 +61,24 @@ func TestcalculateMaxLiquidityPerTick(t *testing.T) { func TestCalculateFeeGrowthInside(t *testing.T) { // Create a mock pool pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup test ticks - pool.ticks[0] = TickInfo{ + pool.ticks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(100), feeGrowthOutside0X128: u256.NewUint(5), feeGrowthOutside1X128: u256.NewUint(7), initialized: true, - } - pool.ticks[100] = TickInfo{ + }) + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(2000), liquidityNet: i256.NewInt(-100), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -261,7 +262,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { func TestTickUpdate(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } tests := []struct { @@ -403,7 +404,7 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + info := pool.mustGetTick(1) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -424,7 +425,7 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + info := pool.mustGetTick(1) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -441,9 +442,9 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[2] - uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "") - uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "") + info := pool.getTick(2) + uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "0") + uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "0") }, }, } @@ -494,17 +495,17 @@ func TestTickUpdate(t *testing.T) { func TestTickCross(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup initial tick state - pool.ticks[100] = TickInfo{ + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(500), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -539,7 +540,7 @@ func TestTickCross(t *testing.T) { func TestGetTick(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup a tick @@ -550,7 +551,7 @@ func TestGetTick(t *testing.T) { feeGrowthOutside1X128: u256.NewUint(15), initialized: true, } - pool.ticks[50] = expectedTick + pool.setTick(50, expectedTick) tests := []struct { name string @@ -710,9 +711,7 @@ func setTick( ) { t.Helper() - info := pool.ticks[tick] - info.valueOrZero() - + info := pool.getTick(tick) info.feeGrowthOutside0X128 = feeGrowthOutside0X128 info.feeGrowthOutside1X128 = feeGrowthOutside1X128 info.liquidityGross = liquidityGross @@ -722,10 +721,10 @@ func setTick( info.secondsOutside = secondsOutside info.initialized = initialized - pool.ticks[tick] = info + pool.setTick(tick, info) } func deleteTick(t *testing.T, pool *Pool, tick int32) { t.Helper() - delete(pool.ticks, tick) + pool.deleteTick(tick) } diff --git a/pool/type.gno b/pool/type.gno index 2cca5dbc9..a94fb96ff 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,6 +3,7 @@ package pool import ( "std" + "gno.land/p/demo/avl" "gno.land/p/demo/ufmt" "gno.land/r/gnoswap/v1/common" @@ -313,12 +314,6 @@ func (t *TickInfo) valueOrZero() { t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() } -type ( - Ticks map[int32]TickInfo // tick => TickInfo - TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) - Positions map[string]PositionInfo // positionKey => PositionInfo -) - // type Pool describes a single Pool's state // A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 type Pool struct { @@ -343,11 +338,11 @@ type Pool struct { liquidity *u256.Uint // total amount of liquidity in the pool - ticks Ticks // maps tick index to tick + ticks *avl.Tree // tick(int32) -> TickInfo - tickBitmaps TickBitmaps // maps tick index to tick bitmap + tickBitmaps *avl.Tree // tick(wordPos)(int16) -> bitMap(tickWord ^ mask)(*u256.Uint) - positions Positions // maps the key (caller, lower tick, upper tick) to a unique position + positions *avl.Tree // maps the key (caller, lower tick, upper tick) to a unique position } func newPool(poolInfo *createPoolParams) *Pool { @@ -367,9 +362,9 @@ func newPool(poolInfo *createPoolParams) *Pool { feeGrowthGlobal1X128: u256.Zero(), protocolFees: newProtocolFees(), liquidity: u256.Zero(), - ticks: Ticks{}, - tickBitmaps: TickBitmaps{}, - positions: Positions{}, + ticks: avl.NewTree(), + tickBitmaps: avl.NewTree(), + positions: avl.NewTree(), } } @@ -442,10 +437,12 @@ func (p *Pool) Liquidity() *u256.Uint { } func mustGetPool(poolPath string) *Pool { - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError(errDataNotFound, - ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) + pool := GetPoolFromPoolPath(poolPath) + if pool == nil { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), + )) } return pool } diff --git a/pool/utils.gno b/pool/utils.gno index c731d3618..9f3017af0 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -82,7 +82,7 @@ func safeConvertToInt128(value *u256.Uint) *i256.Int { // This function validates that the given `value` is properly initialized and checks whether // it exceeds the maximum value of uint128. If the value exceeds the uint128 range, // it applies a masking operation to truncate the value to fit within the uint128 limit. -// +//q // Parameters: // - value: *u256.Uint, the value to be checked and possibly truncated. // From 38d013a07449b990c1d7b10e177643f8cd3c7361 Mon Sep 17 00:00:00 2001 From: Blake <104744707+r3v4s@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:40:59 +0900 Subject: [PATCH 4/4] GSW-1845 refactor: gns.gno (#436) * GSW-1845 fix: remove deprecated require in gno.mod * GSW-1845 refactor: separate math logic from mint() * refactor: change function name to aviod confusion - from `Mint` to `MintGns` * test: reset gns token object * chore * feat: check functino to whether emission ended or not * test: detail unit test * test: mint amount test based on height skipping * feat: condition helper * refactor: better readability * fix: tc * Update _deploy/r/gnoswap/gns/gns.gno Co-authored-by: Dongwon <74406335+dongwon8247@users.noreply.github.com> * Update _deploy/r/gnoswap/gns/gns.gno Co-authored-by: Lee ByeongJun * chore: categorize getter/setter * feat: use min64 to make caluclate logic more straightforward * test: check burn method(burnt amount) effection on mintedAmount, leftEmissionAmount * chore: simplifiy skipIfSameHeight * fix: typo * chore: comments --------- Co-authored-by: Dongwon <74406335+dongwon8247@users.noreply.github.com> Co-authored-by: Lee ByeongJun --- _deploy/r/gnoswap/gns/_helper_test.gno | 20 ++ _deploy/r/gnoswap/gns/gno.mod | 12 - _deploy/r/gnoswap/gns/gns.gno | 209 +++++++++------ _deploy/r/gnoswap/gns/gns_test.gno | 249 ++++++++++++++++++ _deploy/r/gnoswap/gns/halving.gno | 14 +- .../tests/gns_calculate_and_mint_test.gnoA | 99 +++++++ _deploy/r/gnoswap/gns/tests/gns_test.gnoA | 129 --------- .../minted_and_left_emission_amount_test.gnoA | 89 +++++++ _deploy/r/gnoswap/gns/utils.gno | 26 ++ emission/emission.gno | 2 +- 10 files changed, 627 insertions(+), 222 deletions(-) create mode 100644 _deploy/r/gnoswap/gns/_helper_test.gno create mode 100644 _deploy/r/gnoswap/gns/gns_test.gno create mode 100644 _deploy/r/gnoswap/gns/tests/gns_calculate_and_mint_test.gnoA delete mode 100644 _deploy/r/gnoswap/gns/tests/gns_test.gnoA create mode 100644 _deploy/r/gnoswap/gns/tests/minted_and_left_emission_amount_test.gnoA diff --git a/_deploy/r/gnoswap/gns/_helper_test.gno b/_deploy/r/gnoswap/gns/_helper_test.gno new file mode 100644 index 000000000..2e9763908 --- /dev/null +++ b/_deploy/r/gnoswap/gns/_helper_test.gno @@ -0,0 +1,20 @@ +package gns + +import ( + "testing" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + + "gno.land/r/gnoswap/v1/consts" +) + +func testResetGnsTokenObject(t *testing.T) { + t.Helper() + + Token, privateLedger = grc20.NewToken("Gnoswap", "GNS", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress(consts.ADMIN) + + privateLedger.Mint(owner.Owner(), INITIAL_MINT_AMOUNT) +} diff --git a/_deploy/r/gnoswap/gns/gno.mod b/_deploy/r/gnoswap/gns/gno.mod index 8c541067d..67209d1d3 100644 --- a/_deploy/r/gnoswap/gns/gno.mod +++ b/_deploy/r/gnoswap/gns/gno.mod @@ -1,13 +1 @@ module gno.land/r/gnoswap/v1/gns - -require ( - gno.land/p/demo/grc/grc20 v0.0.0-latest - gno.land/p/demo/json v0.0.0-latest - gno.land/p/demo/ownable v0.0.0-latest - gno.land/p/demo/ufmt v0.0.0-latest - gno.land/p/demo/users v0.0.0-latest - gno.land/r/demo/users v0.0.0-latest - gno.land/r/demo/grc20reg v0.0.0-latest - gno.land/r/gnoswap/v1/common v0.0.0-latest - gno.land/r/gnoswap/v1/consts v0.0.0-latest -) diff --git a/_deploy/r/gnoswap/gns/gns.gno b/_deploy/r/gnoswap/gns/gns.gno index 327ffb448..d4c3a682f 100644 --- a/_deploy/r/gnoswap/gns/gns.gno +++ b/_deploy/r/gnoswap/gns/gns.gno @@ -12,43 +12,81 @@ import ( "gno.land/r/demo/grc20reg" "gno.land/r/demo/users" - "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" ) -const MAXIMUM_SUPPLY = uint64(1_000_000_000_000_000) // 1B +const ( + MAXIMUM_SUPPLY = uint64(1_000_000_000_000_000) + INITIAL_MINT_AMOUNT = uint64(100_000_000_000_000) + MAX_EMISSION_AMOUNT = uint64(900_000_000_000_000) // MAXIMUM_SUPPLY - INITIAL_MINT_AMOUNT +) + +var ( + lastMintedHeight = std.GetHeight() +) var ( - lastMintedHeight int64 - amountToEmission uint64 + // Initial amount set to 900_000_000_000_000 (MAXIMUM_SUPPLY - INITIAL_MINT_AMOUNT). + // leftEmissionAmount will decrease as tokens are minted. + leftEmissionAmount = MAX_EMISSION_AMOUNT + mintedEmissionAmount = uint64(0) ) var ( Token, privateLedger = grc20.NewToken("Gnoswap", "GNS", 6) UserTeller = Token.CallerTeller() - owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN + owner = ownable.NewWithAddress(consts.ADMIN) ) func init() { - privateLedger.Mint(owner.Owner(), 100_000_000_000_000) // 100_000_000 GNS ≈ 0.1B + privateLedger.Mint(owner.Owner(), INITIAL_MINT_AMOUNT) getter := func() *grc20.Token { return Token } grc20reg.Register(getter, "") +} + +// MintedEmissionAmount returns the amount of GNS that has been minted by the emission contract. +// It does not include initial minted amount. +func MintedEmissionAmount() uint64 { + return TotalSupply() - INITIAL_MINT_AMOUNT +} - amountToEmission = MAXIMUM_SUPPLY - uint64(100_000_000_000_000) +func MintGns(address pusers.AddressOrName) uint64 { + lastMintedHeight := GetLastMintedHeight() + currentHeight := std.GetHeight() - lastMintedHeight = std.GetHeight() + // skip minting process if gns for current block is already minted + if skipIfSameHeight(lastMintedHeight, currentHeight) { + return 0 + } + + assertShouldNotBeHalted() + assertCallerIsEmission() + + // calculate gns amount to mint, and the mint to the target address + amountToMint := calculateAmountToMint(lastMintedHeight+1, currentHeight) + err := privateLedger.Mint(users.Resolve(address), amountToMint) + if err != nil { + panic(err.Error()) + } + + // update + setLastMintedHeight(currentHeight) + setMintedEmissionAmount(GetMintedEmissionAmount() + amountToMint) + setLeftEmissionAmount(GetLeftEmissionAmount() - amountToMint) + + return amountToMint } -func GetAmountToEmission() uint64 { return amountToEmission } +func Burn(from pusers.AddressOrName, amount uint64) { + owner.AssertCallerIsOwner() + fromAddr := users.Resolve(from) + checkErr(privateLedger.Burn(fromAddr, amount)) +} func TotalSupply() uint64 { return UserTeller.TotalSupply() } -func TotalMinted() uint64 { - return UserTeller.TotalSupply() - uint64(100_000_000_000_000) -} - func BalanceOf(owner pusers.AddressOrName) uint64 { ownerAddr := users.Resolve(owner) return UserTeller.BalanceOf(ownerAddr) @@ -93,93 +131,106 @@ func Render(path string) string { } } -// Mint mints GNS to the address. -// Only emission contract can call Mint. -func Mint(address pusers.AddressOrName) uint64 { - common.IsHalted() - - caller := std.PrevRealm().Addr() - if caller != consts.EMISSION_ADDR { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only emission contract(%s) can call Mint, called from %s", consts.EMISSION_ADDR, caller.String()), - )) +func checkErr(err error) { + if err != nil { + panic(err.Error()) } +} - // if not yet initialized, mint 0 amount - if initialized == false { - return 0 - } +// helper functions - // calculate gns emission amount for every block, and send by single call - // for this case, we assume that inside of block range gnoswap state hasn't changed. - nowHeight := std.GetHeight() - amountToMint := uint64(0) +// calculateAmountToMint calculates the amount of gns to mint +// It calculates the amount of gns to mint for each halving year for block range. +// It also handles the left emission amount if the current block range includes halving year end block. +func calculateAmountToMint(fromHeight, toHeight int64) uint64 { + fromYear := GetHalvingYearByHeight(fromHeight) + toYear := GetHalvingYearByHeight(toHeight) - if lastMintedHeight >= nowHeight { + if isEmissionEnded(fromYear) || isEmissionEnded(toYear) { return 0 } - // If from, to block is at same halving year, no need iterate - fromYear := GetHalvingYearByHeight(lastMintedHeight + 1) - toYear := GetHalvingYearByHeight(nowHeight) - - if fromYear == toYear { - numBlock := uint64(nowHeight - lastMintedHeight) - singleBlockAmount := GetAmountByHeight(nowHeight) - totalBlockAmount := singleBlockAmount * numBlock - - amountToMint += totalBlockAmount - amountToMint = checkAndHandleIfLastBlockOfHalvingYear(nowHeight, amountToMint) - - halvingYearMintAmount[fromYear] += totalBlockAmount - } else { - for i := lastMintedHeight + 1; i <= nowHeight; i++ { - amount := GetAmountByHeight(i) - amount = checkAndHandleIfLastBlockOfHalvingYear(i, amount) - year := GetHalvingYearByHeight(i) - halvingYearMintAmount[year] += amount - amountToMint += amount + totalAmountToMint := uint64(0) + + for i := fromYear; i <= toYear; i++ { + yearEndHeight := GetHalvingYearBlock(i) + mintUntilHeight := i64Min(yearEndHeight, toHeight) + + // how many blocks to calculate + blocks := uint64(mintUntilHeight-fromHeight) + 1 + + // amount of gns to mint for each block for current year + singleBlockAmount := GetAmountByHeight(yearEndHeight) + + // amount of gns to mint for current year + yearAmountToMint := singleBlockAmount * blocks + + // if last block of halving year, handle left emission amount + if isLastBlockOfHalvingYear(mintUntilHeight) { + yearAmountToMint += handleLeftEmissionAmount(i, yearAmountToMint) } - } + totalAmountToMint += yearAmountToMint + SetHalvingYearMintAmount(i, GetHalvingYearMintAmount(i)+yearAmountToMint) - err := privateLedger.Mint(users.Resolve(address), amountToMint) - if err != nil { - panic(err.Error()) + // update fromHeight for next year (if necessary) + fromHeight = mintUntilHeight + 1 } - lastMintedHeight = nowHeight - - return amountToMint + return totalAmountToMint } -func Burn(from pusers.AddressOrName, amount uint64) { - owner.AssertCallerIsOwner() - fromAddr := users.Resolve(from) - checkErr(privateLedger.Burn(fromAddr, amount)) +// isLastBlockOfHalvingYear returns true if the current block is the last block of a halving year. +func isLastBlockOfHalvingYear(height int64) bool { + year := GetHalvingYearByHeight(height) + lastBlock := GetHalvingYearBlock(year) + + return height == lastBlock } -func checkAndHandleIfLastBlockOfHalvingYear(height int64, amount uint64) uint64 { - year := GetHalvingYearByHeight(height) - lastBlock := halvingYearBlock[year] - if height == lastBlock { - leftForThisYear := halvingYearAmount[year] - halvingYearMintAmount[year] - amount = leftForThisYear - return amount - } +// handleLeftEmissionAmount handles the left emission amount for a halving year. +// It calculates the left emission amount by subtracting the halving year mint amount from the halving year amount. +func handleLeftEmissionAmount(year int64, amount uint64) uint64 { + return GetHalvingYearAmount(year) - GetHalvingYearMintAmount(year) - amount +} - return amount +// skipIfSameHeight returns true if the current block height is the same as the last minted height. +// This prevents multiple gns minting inside the same block. +func skipIfSameHeight(lastMintedHeight, currentHeight int64) bool { + return lastMintedHeight == currentHeight } -func checkErr(err error) { - if err != nil { - panic(err.Error()) +// isEmissionEnded returns true if the emission is ended. +// It returns false if the emission is not ended. +func isEmissionEnded(year int64) bool { + if 1 <= year && year <= 12 { + return false } + + return true } -// TODO: -// 1. when emission contract mint gns reward, last executed height should be get from gns contract. -// mint function of gns contract and mintGns function of emission contract should be synchronized. +// Getter func GetLastMintedHeight() int64 { return lastMintedHeight } + +func GetLeftEmissionAmount() uint64 { + return leftEmissionAmount +} + +func GetMintedEmissionAmount() uint64 { + return mintedEmissionAmount +} + +// Setter +func setLastMintedHeight(height int64) { + lastMintedHeight = height +} + +func setLeftEmissionAmount(amount uint64) { + leftEmissionAmount = amount +} + +func setMintedEmissionAmount(amount uint64) { + mintedEmissionAmount = amount +} diff --git a/_deploy/r/gnoswap/gns/gns_test.gno b/_deploy/r/gnoswap/gns/gns_test.gno new file mode 100644 index 000000000..c7657f9bd --- /dev/null +++ b/_deploy/r/gnoswap/gns/gns_test.gno @@ -0,0 +1,249 @@ +package gns + +import ( + "fmt" + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/consts" +) + +const ( + // gnoVM test context default height + // ref: https://github.com/gnolang/gno/blob/a85a53d5b38f0a21d66262a823a8b07f4f836b68/gnovm/pkg/test/test.go#L31-L32 + GNO_VM_DEFAULT_HEIGHT = int64(123) +) + +var ( + emissionRealm = std.NewCodeRealm(consts.EMISSION_PATH) + adminRealm = std.NewUserRealm(consts.ADMIN) +) + +var ( + alice = testutils.TestAddress("alice") + bob = testutils.TestAddress("bob") +) + +func TestIsLastBlockOfHalvingYear(t *testing.T) { + tests := make([]struct { + name string + height int64 + want bool + }, 0, 24) + + for i := int64(1); i <= 12; i++ { + tests = append(tests, struct { + name string + height int64 + want bool + }{ + name: fmt.Sprintf("last block of halving year %d", i), + height: halvingYearBlock[i], + want: true, + }) + + tests = append(tests, struct { + name string + height int64 + want bool + }{ + name: fmt.Sprintf("not last block of halving year %d", i), + height: halvingYearBlock[i] - 1, + want: false, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.want, isLastBlockOfHalvingYear(tt.height)) + }) + } +} + +func TestHandleLeftEmissionAmount(t *testing.T) { + tests := make([]struct { + name string + year int64 + amount uint64 + want uint64 + }, 0, 24) + + for i := int64(1); i <= 12; i++ { + tests = append(tests, struct { + name string + year int64 + amount uint64 + want uint64 + }{ + name: fmt.Sprintf("handle left emission amount for year %d, non minted", i), + year: i, + amount: 0, + want: halvingYearAmount[i], + }) + + tests = append(tests, struct { + name string + year int64 + amount uint64 + want uint64 + }{ + name: fmt.Sprintf("handle left emission amount for year %d, minted", i), + year: i, + amount: uint64(123456), + want: halvingYearAmount[i] - uint64(123456), + }) + } +} + +func TestSkipIfSameHeight(t *testing.T) { + t.Run("should skip if height is same", func(t *testing.T) { + uassert.True(t, skipIfSameHeight(1, 1)) + }) + + t.Run("should not skip if height is different", func(t *testing.T) { + uassert.False(t, skipIfSameHeight(1, 2)) + }) +} + +func TestGetterSetter(t *testing.T) { + t.Run("last minted height", func(t *testing.T) { + value := int64(1234) + setLastMintedHeight(value) + uassert.Equal(t, value, GetLastMintedHeight()) + }) + + t.Run("left emission amount", func(t *testing.T) { + value := uint64(123456) + setLeftEmissionAmount(value) + uassert.Equal(t, value, GetLeftEmissionAmount()) + }) +} + +func TestGrc20Methods(t *testing.T) { + tests := []struct { + name string + fn func() + shouldPanic bool + panicMsg string + }{ + { + name: "TotalSupply", + fn: func() { + uassert.Equal(t, INITIAL_MINT_AMOUNT, TotalSupply()) + }, + }, + { + name: "BalanceOf(admin)", + fn: func() { + uassert.Equal(t, INITIAL_MINT_AMOUNT, BalanceOf(a2u(consts.ADMIN))) + }, + }, + { + name: "BalanceOf(alice)", + fn: func() { + uassert.Equal(t, uint64(0), BalanceOf(a2u(alice))) + }, + }, + { + name: "Allowance(admin, alice)", + fn: func() { + uassert.Equal(t, uint64(0), Allowance(a2u(consts.ADMIN), a2u(alice))) + }, + }, + { + name: "MintGns success", + fn: func() { + std.TestSetRealm(emissionRealm) + MintGns(a2u(consts.ADMIN)) + }, + }, + { + name: "MintGns without permission should panic", + fn: func() { + std.TestSkipHeights(1) + MintGns(a2u(consts.ADMIN)) + }, + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + }, + { + name: "Burn success", + fn: func() { + std.TestSetRealm(adminRealm) + Burn(a2u(consts.ADMIN), uint64(1)) + }, + }, + { + name: "Burn without permission should panic", + fn: func() { + Burn(a2u(consts.ADMIN), uint64(1)) + }, + shouldPanic: true, + panicMsg: `ownable: caller is not owner`, + }, + { + name: "Transfer success", + fn: func() { + std.TestSetRealm(adminRealm) + Transfer(a2u(alice), uint64(1)) + }, + }, + { + name: "Transfer without enough balance should panic", + fn: func() { + std.TestSetRealm(std.NewUserRealm(alice)) + Transfer(a2u(bob), uint64(1)) + }, + shouldPanic: true, + panicMsg: `insufficient balance`, + }, + { + name: "Transfer to self should panic", + fn: func() { + std.TestSetRealm(adminRealm) + Transfer(a2u(consts.ADMIN), uint64(1)) + }, + shouldPanic: true, + panicMsg: `cannot send transfer to self`, + }, + { + name: "TransferFrom success", + fn: func() { + // approve first + std.TestSetRealm(adminRealm) + Approve(a2u(alice), uint64(1)) + + // alice transfer admin's balance to bob + std.TestSetRealm(std.NewUserRealm(alice)) + TransferFrom(a2u(consts.ADMIN), a2u(bob), uint64(1)) + }, + }, + { + name: "TransferFrom without enough allowance should panic", + fn: func() { + std.TestSetRealm(adminRealm) + Approve(a2u(alice), uint64(1)) + + std.TestSetRealm(std.NewUserRealm(alice)) + TransferFrom(a2u(consts.ADMIN), a2u(bob), uint64(2)) + }, + shouldPanic: true, + panicMsg: `insufficient allowance`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testResetGnsTokenObject(t) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, tt.fn) + } else { + uassert.NotPanics(t, func() { tt.fn() }) + } + }) + } +} diff --git a/_deploy/r/gnoswap/gns/halving.gno b/_deploy/r/gnoswap/gns/halving.gno index 9907be7b3..934a34bd0 100644 --- a/_deploy/r/gnoswap/gns/halving.gno +++ b/_deploy/r/gnoswap/gns/halving.gno @@ -153,7 +153,7 @@ func setAvgBlockTimeInMs(ms int64) { blockLeft := timeLeftMs / avgBlockTimeMs // how many reward left to next halving - minted := TotalMinted() + minted := MintedEmissionAmount() amountLeft := halvingYearAccuAmount[year] - minted // how much reward per block @@ -176,6 +176,18 @@ func setAvgBlockTimeInMs(ms int64) { } } +func GetHalvingYearAmount(year int64) uint64 { + return halvingYearAmount[year] +} + +func GetHalvingYearMintAmount(year int64) uint64 { + return halvingYearMintAmount[year] +} + +func SetHalvingYearMintAmount(year int64, amount uint64) { + halvingYearMintAmount[year] = amount +} + func GetAmountByHeight(height int64) uint64 { halvingYear := GetHalvingYearByHeight(height) diff --git a/_deploy/r/gnoswap/gns/tests/gns_calculate_and_mint_test.gnoA b/_deploy/r/gnoswap/gns/tests/gns_calculate_and_mint_test.gnoA new file mode 100644 index 000000000..a5cfd9931 --- /dev/null +++ b/_deploy/r/gnoswap/gns/tests/gns_calculate_and_mint_test.gnoA @@ -0,0 +1,99 @@ +package gns + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" +) + +var ( + emissionRealm = std.NewCodeRealm(consts.EMISSION_PATH) + adminRealm = std.NewUserRealm(consts.ADMIN) +) + +func TestCalculateAmountToMint(t *testing.T) { + t.Run("1 block for same year 01", func(t *testing.T) { + amount := calculateAmountToMint(GetLastMintedHeight()+1, GetLastMintedHeight()+1) + uassert.Equal(t, amountPerBlockPerHalvingYear[1], amount) + }) + + t.Run("2 block for same year 01", func(t *testing.T) { + amount := calculateAmountToMint(GetLastMintedHeight()+1, GetLastMintedHeight()+2) + uassert.Equal(t, amountPerBlockPerHalvingYear[2]*2, amount) + }) + + t.Run("entire block for year 01 + 1 block for year 02", func(t *testing.T) { + calculateAmountToMint(halvingYearBlock[1], halvingYearBlock[1]+1) + + // minted all amount for year 01 + uassert.Equal(t, GetHalvingYearAmount(1), GetHalvingYearMintAmount(1)) + + // minted 1 block for year 02 + uassert.Equal(t, amountPerBlockPerHalvingYear[2], GetHalvingYearMintAmount(2)) + }) + + t.Run("entire block for 12 years", func(t *testing.T) { + calculateAmountToMint(halvingYearBlock[1], halvingYearBlock[12]) + + for year := int64(1); year <= 12; year++ { + uassert.Equal(t, GetHalvingYearAmount(year), GetHalvingYearMintAmount(year)) + } + }) + + t.Run("no emission amount for after 12 years", func(t *testing.T) { + amount := calculateAmountToMint(halvingYearBlock[12], halvingYearBlock[12]+1) + uassert.Equal(t, uint64(0), amount) + }) + + // clear for further test + halvingYearMintAmount = make(map[int64]uint64) +} + +func TestMintGns(t *testing.T) { + t.Run("panic for swap is halted", func(t *testing.T) { + std.TestSetRealm(adminRealm) + common.SetHaltByAdmin(true) // set halt + uassert.PanicsWithMessage(t, "[GNOSWAP-COMMON-002] halted || gnoswap halted", func() { + MintGns(a2u(consts.ADMIN)) + }) + + common.SetHaltByAdmin(false) // unset halt + }) + + t.Run("panic if caller is not emission contract", func(t *testing.T) { + uassert.PanicsWithMessage(t, "caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission", func() { + MintGns(a2u(consts.ADMIN)) + }) + }) + + t.Run("do not mint for same block", func(t *testing.T) { + std.TestSetRealm(emissionRealm) + mintedAmount := MintGns(a2u(consts.ADMIN)) + uassert.Equal(t, uint64(0), mintedAmount) + }) + + t.Run("mint by year, until emission ends", func(t *testing.T) { + for year := int64(1); year <= 12; year++ { + std.TestSkipHeights(BLOCK_PER_YEAR) + + std.TestSetRealm(emissionRealm) + mintedAmount := MintGns(a2u(consts.ADMIN)) + + uassert.Equal(t, halvingYearAmount[year], mintedAmount) + uassert.Equal(t, halvingYearAmount[year], halvingYearMintAmount[year]) + uassert.Equal(t, halvingYearAccuAmount[year], MintedEmissionAmount()) + } + }) + + t.Run("no more emission after it ends", func(t *testing.T) { + std.TestSkipHeights(BLOCK_PER_YEAR) + + std.TestSetRealm(emissionRealm) + mintedAmount := MintGns(a2u(consts.ADMIN)) + uassert.Equal(t, uint64(0), mintedAmount) + }) +} diff --git a/_deploy/r/gnoswap/gns/tests/gns_test.gnoA b/_deploy/r/gnoswap/gns/tests/gns_test.gnoA deleted file mode 100644 index 0b17cfc2a..000000000 --- a/_deploy/r/gnoswap/gns/tests/gns_test.gnoA +++ /dev/null @@ -1,129 +0,0 @@ -package gns - -import ( - "std" - "testing" - - "gno.land/p/demo/testutils" - "gno.land/p/demo/uassert" - pusers "gno.land/p/demo/users" - - "gno.land/r/gnoswap/v1/consts" -) - -func TestMint(t *testing.T) { - t.Run("initial mint", func(t *testing.T) { - uassert.Equal(t, uint64(100_000_000_000_000), TotalSupply()) - uassert.Equal(t, int64(123), lastMintedHeight) - }) - - t.Run("panic if not emission", func(t *testing.T) { - uassert.PanicsWithMessage(t, - `[GNOSWAP-GNS-001] caller has no permission || only emission contract(g10xg6559w9e93zfttlhvdmaaa0er3zewcr7nh20) can call Mint, called from g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm`, - func() { Mint(pusers.AddressOrName(testutils.TestAddress("dummy"))) }) - }) - - t.Run("no block mined", func(t *testing.T) { - std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) - Mint(pusers.AddressOrName(consts.EMISSION_ADDR)) - - uassert.Equal(t, uint64(100_000_000_000_000), TotalSupply()) - uassert.Equal(t, int64(123), lastMintedHeight) - }) - - t.Run("1 block mined", func(t *testing.T) { - std.TestSkipHeights(1) - - std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) - Mint(pusers.AddressOrName(consts.EMISSION_ADDR)) - - uassert.Equal(t, uint64(100_000_000_000_000+14_269_406), TotalSupply()) - uassert.Equal(t, int64(124), lastMintedHeight) - }) - - t.Run("10 blocks mined", func(t *testing.T) { - std.TestSkipHeights(10) - - std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) - Mint(pusers.AddressOrName(consts.EMISSION_ADDR)) - - uassert.Equal(t, uint64(100000014269406+142_694_060), TotalSupply()) - uassert.Equal(t, int64(134), lastMintedHeight) - }) - - t.Run("reach first halving year", func(t *testing.T) { - height := std.GetHeight() - uassert.Equal(t, int64(134), height) - - year := GetHalvingYearByHeight(height) - uassert.Equal(t, int64(1), year) - - yearEndHeight := halvingYearBlock[year] - uassert.Equal(t, int64(15768123), yearEndHeight) - - leftBlock := yearEndHeight - height - uassert.Equal(t, int64(15767989), leftBlock) - - std.TestSkipHeights(15767980) // 9 block left to next halving year - - std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) - Mint(pusers.AddressOrName(consts.EMISSION_ADDR)) - - uassert.Equal(t, uint64(100000156963466+224_999_708_419_880), TotalSupply()) - // 324999865383346 - }) - - t.Run("year01 and year02", func(t *testing.T) { - height := std.GetHeight() - uassert.Equal(t, int64(15768114), height) - - year := GetHalvingYearByHeight(height) - uassert.Equal(t, int64(1), year) - - yearEndHeight := halvingYearBlock[year] - uassert.Equal(t, int64(15768123), yearEndHeight) - - leftBlock := yearEndHeight - height - uassert.Equal(t, int64(9), leftBlock) // 9 block left - - std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) - std.TestSkipHeights(9) // year 1 ends - std.TestSkipHeights(10) // year 2 starts and mined 10 blocks - - Mint(pusers.AddressOrName(consts.EMISSION_ADDR)) - - uassert.Equal(t, halvingYearMintAmount[1], uint64(225000000000000)) - uassert.Equal(t, halvingYearAmount[1], uint64(225000000000000)) - - uassert.Equal(t, TotalSupply(), uint64(325000142694060)) - // 325000142694060 - 324999865383346 - // minted: 277310714 - - // year 1 - // block 15768115: 14_269_406 - // block 15768116: 14_269_406 - // block 15768117: 14_269_406 - // block 15768118: 14_269_406 - // block 15768119: 14_269_406 - // block 15768120: 14_269_406 - // block 15768121: 14_269_406 - // block 15768122: 14_269_406 - // block 15768123: 14_269_406 + (left) 6_192_000 - // 128424654 + 6192000 = 134616654 - - // year 2 - // block 15768124: 14_269_406 - // block 15768125: 14_269_406 - // block 15768126: 14_269_406 - // block 15768127: 14_269_406 - // block 15768128: 14_269_406 - // block 15768129: 14_269_406 - // block 15768130: 14_269_406 - // block 15768131: 14_269_406 - // block 15768132: 14_269_406 - // block 15768133: 14_269_406 - // 142694060 - - // 134616654 + 142694060 = 277310714 - }) -} diff --git a/_deploy/r/gnoswap/gns/tests/minted_and_left_emission_amount_test.gnoA b/_deploy/r/gnoswap/gns/tests/minted_and_left_emission_amount_test.gnoA new file mode 100644 index 000000000..aee171e6b --- /dev/null +++ b/_deploy/r/gnoswap/gns/tests/minted_and_left_emission_amount_test.gnoA @@ -0,0 +1,89 @@ +package gns + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/consts" +) + +var ( + emissionRealm = std.NewCodeRealm(consts.EMISSION_PATH) + adminRealm = std.NewUserRealm(consts.ADMIN) +) + +func TestCheckInitialData(t *testing.T) { + t.Run("totalSupply", func(t *testing.T) { + uassert.Equal(t, INITIAL_MINT_AMOUNT, TotalSupply()) + }) + + t.Run("mintedAmount", func(t *testing.T) { + uassert.Equal(t, uint64(0), MintedEmissionAmount()) + }) + + t.Run("leftEmissionAmount", func(t *testing.T) { + uassert.Equal(t, MAX_EMISSION_AMOUNT, GetLeftEmissionAmount()) + }) +} + +func TestMintAndCheckRelativeData(t *testing.T) { + // before mint + oldTotalSupply := TotalSupply() + oldMintedAmount := MintedEmissionAmount() + oldLeftEmissionAmount := GetLeftEmissionAmount() + + // mint + mintAmountFor10Blocks := uint64(142694060) + t.Run("mint for 10 blocks", func(t *testing.T) { + std.TestSetRealm(emissionRealm) + std.TestSkipHeights(10) + mintedAmount := MintGns(a2u(consts.ADMIN)) + uassert.Equal(t, mintAmountFor10Blocks, mintedAmount) + }) + + // after mint + t.Run("increment of totalSupply", func(t *testing.T) { + uassert.Equal(t, oldTotalSupply+mintAmountFor10Blocks, TotalSupply()) + }) + + t.Run("increment of mintedAmount", func(t *testing.T) { + uassert.Equal(t, oldMintedAmount+mintAmountFor10Blocks, MintedEmissionAmount()) + }) + + t.Run("decrement of leftEmissionAmount", func(t *testing.T) { + uassert.Equal(t, oldLeftEmissionAmount-mintAmountFor10Blocks, GetLeftEmissionAmount()) + }) +} + +func TestBurnAndCheckRelativeData(t *testing.T) { + // before burn + oldTotalSupply := TotalSupply() + oldMintedAmount := MintedEmissionAmount() + oldLeftEmissionAmount := GetLeftEmissionAmount() + + // burn + burnAmount := uint64(100000000) + t.Run("burn amount", func(t *testing.T) { + std.TestSetRealm(adminRealm) + Burn(a2u(consts.ADMIN), burnAmount) + }) + + // after burn + t.Run("decrement of totalSupply", func(t *testing.T) { + uassert.Equal(t, oldTotalSupply-burnAmount, TotalSupply()) + }) + + t.Run("decrement of mintedAmount", func(t *testing.T) { + uassert.Equal(t, oldMintedAmount-burnAmount, MintedEmissionAmount()) + }) + + t.Run("totalSupply should be same with (INITIAL_MINT_AMOUNT) + (mintedEmissionAmount)", func(t *testing.T) { + uassert.Equal(t, TotalSupply(), INITIAL_MINT_AMOUNT+MintedEmissionAmount()) + }) + + t.Run("same for leftEmissionAmount", func(t *testing.T) { + uassert.Equal(t, oldLeftEmissionAmount, GetLeftEmissionAmount()) + }) +} diff --git a/_deploy/r/gnoswap/gns/utils.gno b/_deploy/r/gnoswap/gns/utils.gno index 5948421ee..6aa58902e 100644 --- a/_deploy/r/gnoswap/gns/utils.gno +++ b/_deploy/r/gnoswap/gns/utils.gno @@ -2,9 +2,35 @@ package gns import ( "std" + + pusers "gno.land/p/demo/users" + + "gno.land/r/gnoswap/v1/common" ) func getPrev() (string, string) { prev := std.PrevRealm() return prev.Addr().String(), prev.PkgPath() } + +func a2u(addr std.Address) pusers.AddressOrName { + return pusers.AddressOrName(addr) +} + +func assertShouldNotBeHalted() { + common.IsHalted() +} + +func assertCallerIsEmission() { + caller := std.PrevRealm().Addr() + if err := common.EmissionOnly(caller); err != nil { + panic(err) + } +} + +func i64Min(x, y int64) int64 { + if x < y { + return x + } + return y +} diff --git a/emission/emission.gno b/emission/emission.gno index ff2fcbb43..f781acc48 100644 --- a/emission/emission.gno +++ b/emission/emission.gno @@ -49,7 +49,7 @@ func MintAndDistributeGns() { return } - mintedEmissionRewardAmount := mintGns() + mintedEmissionRewardAmount := gns.MintGns(a2u(consts.EMISSION_ADDR)) if hasLeftGNSAmount() { mintedEmissionRewardAmount += GetLeftGNSAmount() SetLeftGNSAmount(0)