diff --git a/_deploy/r/gnoswap/consts/consts.gno b/_deploy/r/gnoswap/consts/consts.gno index 9f9228e9..d2ebff66 100644 --- a/_deploy/r/gnoswap/consts/consts.gno +++ b/_deploy/r/gnoswap/consts/consts.gno @@ -6,9 +6,8 @@ import ( // GNOSWAP SERVICE const ( - ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" // Admin - DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" // DevOps - + ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" + DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" TOKEN_REGISTER std.Address = "g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" TOKEN_REGISTER_NAMESPACE string = "gno.land/r/g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" @@ -21,7 +20,8 @@ const ( GNOT string = "gnot" WRAPPED_WUGNOT string = "gno.land/r/demo/wugnot" - UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 ) // CONTRACT PATH & ADDRESS @@ -91,9 +91,11 @@ const ( MAX_UINT128 string = "340282366920938463463374607431768211455" MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" - MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT128 string = "170141183460469231731687303715884105727" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" + // Tick Related MIN_TICK int32 = -887272 MAX_TICK int32 = 887272 @@ -108,6 +110,8 @@ const ( Q64 string = "18446744073709551616" // 2 ** 64 Q96 string = "79228162514264337593543950336" // 2 ** 96 Q128 string = "340282366920938463463374607431768211456" // 2 ** 128 + + Q128_RESOLUTION uint = 128 ) // TIMESTAMP & DAY diff --git a/pool/errors.gno b/pool/errors.gno index 11c4efa5..c98286a4 100644 --- a/pool/errors.gno +++ b/pool/errors.gno @@ -31,8 +31,13 @@ var ( errTransferFailed = errors.New("[GNOSWAP-POOL-021] token transfer failed") errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") - errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than tickUpper") - errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") // TODO: make as common error code + errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper") + errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") + errOverFlow = errors.New("[GNOSWAP-POOL-026] overflow") + errBalanceUpdateFailed = errors.New("[GNOSWAP-POOL-027] balance update failed") + errTickLowerInvalid = errors.New("[GNOSWAP-POOL-028] tickLower is invalid") + errTickUpperInvalid = errors.New("[GNOSWAP-POOL-029] tickUpper is invalid") + errTickLowerGtTickUpper = errors.New("[GNOSWAP-POOL-030] tickLower is greater than tickUpper") ) // addDetailToError adds detail to an error message diff --git a/pool/getter.gno b/pool/getter.gno index f962846f..44351614 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -1,6 +1,5 @@ package pool -// pool func PoolGetPoolList() []string { poolPaths := []string{} for poolPath, _ := range pools { @@ -11,91 +10,89 @@ func PoolGetPoolList() []string { } func PoolGetToken0Path(poolPath string) string { - return mustGetPool(poolPath).GetToken0Path() + return mustGetPool(poolPath).Token0Path() } func PoolGetToken1Path(poolPath string) string { - return mustGetPool(poolPath).GetToken1Path() + return mustGetPool(poolPath).Token1Path() } func PoolGetFee(poolPath string) uint32 { - return mustGetPool(poolPath).GetFee() + return mustGetPool(poolPath).Fee() } func PoolGetBalanceToken0(poolPath string) string { - return mustGetPool(poolPath).GetBalanceToken0().ToString() + return mustGetPool(poolPath).BalanceToken0().ToString() } func PoolGetBalanceToken1(poolPath string) string { - return mustGetPool(poolPath).GetBalanceToken1().ToString() + return mustGetPool(poolPath).BalanceToken1().ToString() } func PoolGetTickSpacing(poolPath string) int32 { - return mustGetPool(poolPath).GetTickSpacing() + return mustGetPool(poolPath).TickSpacing() } func PoolGetMaxLiquidityPerTick(poolPath string) string { - return mustGetPool(poolPath).GetMaxLiquidityPerTick().ToString() + return mustGetPool(poolPath).MaxLiquidityPerTick().ToString() } func PoolGetSlot0SqrtPriceX96(poolPath string) string { - return mustGetPool(poolPath).GetSlot0SqrtPriceX96().ToString() + return mustGetPool(poolPath).Slot0SqrtPriceX96().ToString() } func PoolGetSlot0Tick(poolPath string) int32 { - return mustGetPool(poolPath).GetSlot0Tick() + return mustGetPool(poolPath).Slot0Tick() } func PoolGetSlot0FeeProtocol(poolPath string) uint8 { - return mustGetPool(poolPath).GetSlot0FeeProtocol() + return mustGetPool(poolPath).Slot0FeeProtocol() } func PoolGetSlot0Unlocked(poolPath string) bool { - return mustGetPool(poolPath).GetSlot0Unlocked() + return mustGetPool(poolPath).Slot0Unlocked() } func PoolGetFeeGrowthGlobal0X128(poolPath string) string { - return mustGetPool(poolPath).GetFeeGrowthGlobal0X128().ToString() + return mustGetPool(poolPath).FeeGrowthGlobal0X128().ToString() } func PoolGetFeeGrowthGlobal1X128(poolPath string) string { - return mustGetPool(poolPath).GetFeeGrowthGlobal1X128().ToString() + return mustGetPool(poolPath).FeeGrowthGlobal1X128().ToString() } func PoolGetProtocolFeesToken0(poolPath string) string { - return mustGetPool(poolPath).GetProtocolFeesToken0().ToString() + return mustGetPool(poolPath).ProtocolFeesToken0().ToString() } func PoolGetProtocolFeesToken1(poolPath string) string { - return mustGetPool(poolPath).GetProtocolFeesToken1().ToString() + return mustGetPool(poolPath).ProtocolFeesToken1().ToString() } func PoolGetLiquidity(poolPath string) string { - return mustGetPool(poolPath).GetLiquidity().ToString() + return mustGetPool(poolPath).Liquidity().ToString() } -// position func PoolGetPositionLiquidity(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionLiquidity(key).ToString() + return mustGetPool(poolPath).PositionLiquidity(key).ToString() } func PoolGetPositionFeeGrowthInside0LastX128(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionFeeGrowthInside0LastX128(key).ToString() + return mustGetPool(poolPath).PositionFeeGrowthInside0LastX128(key).ToString() } func PoolGetPositionFeeGrowthInside1LastX128(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionFeeGrowthInside1LastX128(key).ToString() + return mustGetPool(poolPath).PositionFeeGrowthInside1LastX128(key).ToString() } func PoolGetPositionTokensOwed0(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionTokensOwed0(key).ToString() + return mustGetPool(poolPath).PositionTokensOwed0(key).ToString() } func PoolGetPositionTokensOwed1(poolPath, key string) string { - return mustGetPool(poolPath).GetPositionTokensOwed1(key).ToString() + return mustGetPool(poolPath).PositionTokensOwed1(key).ToString() } -// tick func PoolGetTickLiquidityGross(poolPath string, tick int32) string { return mustGetPool(poolPath).GetTickLiquidityGross(tick).ToString() } diff --git a/pool/liquidity_math.gno b/pool/liquidity_math.gno index 47aaea2c..82a8b431 100644 --- a/pool/liquidity_math.gno +++ b/pool/liquidity_math.gno @@ -44,7 +44,7 @@ func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if z.Gte(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than Condition(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Condition failed: (z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } else { @@ -52,7 +52,7 @@ func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if z.Lt(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Condition failed: (z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } diff --git a/pool/liquidity_math_test.gno b/pool/liquidity_math_test.gno index fc2f379b..6c38f3df 100644 --- a/pool/liquidity_math_test.gno +++ b/pool/liquidity_math_test.gno @@ -42,7 +42,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than Condition(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), + ufmt.Sprintf("Condition failed: (z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), }, { name: "overflow panic with add delta", @@ -53,7 +53,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), + ufmt.Sprintf("Condition failed: (z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), }, } diff --git a/pool/pool.gno b/pool/pool.gno index db4df9b8..df2e51ee 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -12,8 +12,32 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -// Mint creates a new position and mints liquidity tokens. -// Returns minted amount0, amount1 in string +// Mint adds liquidity to a pool by minting a new position. +// +// This function mints a liquidity position within the specified tick range in a pool. +// It verifies caller permissions, validates inputs, and updates the pool's state. Additionally, +// it transfers the required amounts of token0 and token1 from the caller to the pool. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the newly created position. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to add, provided as a decimal string. +// - positionCaller: std.Address, the address of the entity calling the function (e.g., the position owner). +// +// Returns: +// - string: The amount of token 0 transferred to the pool as a string. +// - string: The amount of token 1 transferred to the pool as a string. +// +// Panic Conditions: +// - The system is halted (`common.IsHalted()`). +// - Caller lacks permission to mint a position when `common.GetLimitCaller()` is enforced. +// - The provided `liquidityAmount` is zero. +// - Any failure during token transfers or position modifications. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#mint func Mint( token0Path string, @@ -22,43 +46,54 @@ func Mint( recipient std.Address, tickLower int32, tickUpper int32, - _liquidityAmount string, + liquidityAmount string, positionCaller std.Address, ) (string, string) { - common.IsHalted() + assertOnlyNotHalted() if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } - liquidityAmount := u256.MustFromDecimal(_liquidityAmount) - if liquidityAmount.IsZero() { + liquidity := u256.MustFromDecimal(liquidityAmount) + if liquidity.IsZero() { panic(errZeroLiquidity) } pool := GetPool(token0Path, token1Path, fee) - position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + liquidityDelta := safeConvertToInt128(liquidity) + position := newModifyPositionParams(recipient, tickLower, tickUpper, liquidityDelta) _, amount0, amount1 := pool.modifyPosition(position) if amount0.Gt(u256.Zero()) { - pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) } if amount1.Gt(u256.Zero()) { - pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) } return amount0.ToString(), amount1.ToString() } -// Burn removes liquidity from the caller and account tokens owed for the liquidity to the position -// If liquidity of 0 is burned, it recalculates fees owed to a position -// Returns burned amount0, amount1 in string +// Burn removes liquidity from a position in the pool. +// +// This function allows the caller to burn (remove) a specified amount of liquidity from a position. +// It calculates the amounts of token0 and token1 released when liquidity is removed and updates +// the position's owed token amounts. The actual transfer of tokens back to the caller happens +// during a separate `Collect()` operation. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to remove, provided as a decimal string (uint128). +// +// Returns: +// - string: The amount of token0 owed after removing liquidity, as a string. +// - string: The amount of token1 owed after removing liquidity, as a string. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#burn func Burn( token0Path string, @@ -68,40 +103,54 @@ func Burn( tickUpper int32, liquidityAmount string, // uint128 ) (string, string) { // uint256 x2 - common.IsHalted() - caller := std.PrevRealm().Addr() + assertOnlyNotHalted() if common.GetLimitCaller() { - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } - - liqAmount := u256.MustFromDecimal(liquidityAmount) - pool := GetPool(token0Path, token1Path, fee) - liqDelta := i256.Zero().Neg(i256.FromUint256(liqAmount)) + caller := getPrevAddr() + liqAmount := u256.MustFromDecimal(liquidityAmount) + liqAmountInt256 := safeConvertToInt128(liqAmount) + liqDelta := i256.Zero().Neg(liqAmountInt256) posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) position, amount0, amount1 := pool.modifyPosition(posParams) if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { + amount0 = toUint128(amount0) + amount1 = toUint128(amount1) position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) } - positionKey := positionGetKey(caller, tickLower, tickUpper) + positionKey := getPositionKey(caller, tickLower, tickUpper) pool.positions[positionKey] = position // actual token transfer happens in Collect() return amount0.ToString(), amount1.ToString() } -// Collect collects tokens owed to a position -// Burned amounts, and swap fees will be transferred to the caller -// Returns collected amount0, amount1 in string +// Collect handles the collection of tokens (token0 and token1) from a liquidity position. +// +// This function allows the caller to collect a specified amount of tokens owed to a position +// in a liquidity pool. It calculates the collectible amount based on three constraints: +// the requested amount, the tokens owed to the position, and the pool's available balance. +// The collected tokens are transferred to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected tokens. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - amount0Requested: string, the requested amount of token 0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token 1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token 0 collected, as a string. +// - string: The actual amount of token 1 collected, as a string. +// // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#collect func Collect( token0Path string, @@ -113,76 +162,61 @@ func Collect( amount0Requested string, amount1Requested string, ) (string, string) { - common.IsHalted() + assertOnlyNotHalted() if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.PositionOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), - )) - } + assertOnlyPositionContract() } pool := GetPool(token0Path, token1Path, fee) - - positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper) - position, exist := pool.positions[positionKey] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("positionKey(%s) does not exist", positionKey), - )) - } + positionKey := getPositionKey(getPrevAddr(), tickLower, tickUpper) + position := pool.mustGetPosition(positionKey) var amount0, amount1 *u256.Uint // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 amount0Req := u256.MustFromDecimal(amount0Requested) - amount0, position.tokensOwed0, pool.balances.token0 = collectToken(amount0Req, position.tokensOwed0, pool.balances.token0) - token0 := common.GetTokenTeller(pool.token0Path) - checkTransferError(token0.Transfer(recipient, amount0.Uint64())) - - // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount0 = collectToken(amount0Req, position.tokensOwed0, pool.BalanceToken0()) amount1Req := u256.MustFromDecimal(amount1Requested) - amount1, position.tokensOwed1, pool.balances.token1 = collectToken(amount1Req, position.tokensOwed1, pool.balances.token1) + amount1 = collectToken(amount1Req, position.tokensOwed1, pool.BalanceToken1()) - // Update state first then transfer - position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) - pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) - token1 := common.GetTokenTeller(pool.token1Path) - checkTransferError(token1.Transfer(recipient, amount1.Uint64())) + if amount0.Gt(u256.Zero()) { + position.tokensOwed0 = new(u256.Uint).Sub(position.tokensOwed0, amount0) + pool.balances.token0 = new(u256.Uint).Sub(pool.balances.token0, amount0) + token0 := common.GetTokenTeller(pool.token0Path) + checkTransferError(token0.Transfer(recipient, amount0.Uint64())) + } + if amount1.Gt(u256.Zero()) { + position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) + pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) + token1 := common.GetTokenTeller(pool.token1Path) + checkTransferError(token1.Transfer(recipient, amount1.Uint64())) + } pool.positions[positionKey] = position return amount0.ToString(), amount1.ToString() } -// collectToken handles the collection of tokens (either token0 or token1) from a position. -// It calculates the actual amount that can be collected based on three constraints: -// the requested amount, tokens owed, and available pool balance. +// collectToken calculates the actual amount of tokens that can be collected. +// +// This function determines the smallest possible value among the requested amount (`amountReq`), +// the tokens owed (`tokensOwed`), and the pool's available balance (`poolBalance`). It ensures +// the collected amount does not exceed any of these constraints. // // Parameters: -// - amountReq: amount requested to collect -// - tokensOwed: amount of tokens owed to the position -// - poolBalance: current balance of tokens in the pool +// - amountReq: *u256.Uint, the amount of tokens requested for collection. +// - tokensOwed: *u256.Uint, the total amount of tokens owed to the position. +// - poolBalance: *u256.Uint, the current balance of tokens available in the pool. // // Returns: -// - amount: actual amount that will be collected (minimum of the three inputs) -// - newTokensOwed: remaining tokens owed after collection -// - newPoolBalance: remaining pool balance after collection +// - amount: *u256.Uint, the actual amount that can be collected (minimum of the three inputs). func collectToken( amountReq, tokensOwed, poolBalance *u256.Uint, -) (amount, newTokensOwed, newPoolBalance *u256.Uint) { +) (amount *u256.Uint) { // find smallest of three amounts amount = u256Min(amountReq, tokensOwed) amount = u256Min(amount, poolBalance) - - // value for update state - newTokensOwed = new(u256.Uint).Sub(tokensOwed, amount) - newPoolBalance = new(u256.Uint).Sub(poolBalance, amount) - - return amount, newTokensOwed, newPoolBalance + return amount.Clone() } // SetFeeProtocolByAdmin sets the fee protocol for all pools @@ -191,22 +225,8 @@ func SetFeeProtocolByAdmin( feeProtocol0 uint8, feeProtocol1 uint8, ) { - caller := std.PrevRealm().Addr() - if err := common.AdminOnly(caller); err != nil { - panic(err) - } - - newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - - prevAddr, prevRealm := getPrev() - std.Emit( - "SetFeeProtocolByAdmin", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "feeProtocol0", ufmt.Sprintf("%d", feeProtocol0), - "feeProtocol1", ufmt.Sprintf("%d", feeProtocol1), - "internal_newFee", ufmt.Sprintf("%d", newFee), - ) + assertOnlyAdmin() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocolByAdmin") } // SetFeeProtocol sets the fee protocol for all pools @@ -214,40 +234,62 @@ func SetFeeProtocolByAdmin( // Also it will be applied to new created pools // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#setfeeprotocol func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { - caller := std.PrevRealm().Addr() - if err := common.GovernanceOnly(caller); err != nil { - panic(err) - } + assertOnlyGovernance() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocol") +} +// setFeeProtocolInternal updates the protocol fee for all pools and emits an event. +// +// This function is an internal utility used to set the protocol fee for token0 and token1 in a compact +// format. The fee values are stored as a single `uint8` byte where: +// - Lower 4 bits represent the fee for token0 (feeProtocol0). +// - Upper 4 bits represent the fee for token1 (feeProtocol1). +// +// It also emits an event to log the changes, including the previous and new fee protocol values. +// +// Parameters: +// - feeProtocol0: uint8, protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: uint8, protocol fee for token1 (must be 0 or between 4 and 10 inclusive). +// - eventName: string, the name of the event to emit (e.g., "SetFeeProtocolByAdmin"). +// +// Notes: +// - This function is called by higher-level functions like `SetFeeProtocolByAdmin` or `SetFeeProtocol`. +// - It does not validate caller permissions; validation must be performed by the calling function. +func setFeeProtocolInternal(feeProtocol0, feeProtocol1 uint8, eventName string) { + oldFee := slot0FeeProtocol newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - prevAddr, prevRealm := getPrev() + feeProtocol0Old := oldFee % 16 + feeProtocol1Old := oldFee >> 4 + + prevAddr, prevPkgPath := getPrevAsString() std.Emit( - "SetFeeProtocol", + eventName, "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, + "feeProtocol0Old", ufmt.Sprintf("%d", feeProtocol0Old), + "feeProtocol1Old", ufmt.Sprintf("%d", feeProtocol1Old), "feeProtocol0", ufmt.Sprintf("%d", feeProtocol0), "feeProtocol1", ufmt.Sprintf("%d", feeProtocol1), "internal_newFee", ufmt.Sprintf("%d", newFee), ) } -// setFeeProtocol updates the protocol fee configuration for all existing pools and sets -// the default for new pools. This is an internal function called by both `admin` and `governance` -// protocol fee management functions. +// setFeeProtocol updates the protocol fee configuration for all managed pools. // -// The protocol fee is stored as a single `uint8` value where: -// - Lower 4 bits store feeProtocol0 (for token0) -// - Upper 4 bits store feeProtocol1 (for token1) +// This function combines the protocol fee values for token0 and token1 into a single `uint8` value, +// where: +// - Lower 4 bits store feeProtocol0 (for token0). +// - Upper 4 bits store feeProtocol1 (for token1). // -// This compact representation allows storing both fee values in a single byte. +// The updated fee protocol is applied uniformly to all pools managed by the system. // -// Parameters (must be 0 or between 4 and 10 inclusive): -// - feeProtocol0: protocol fee for token0 -// - feeProtocol1: protocol fee for token1 +// Parameters: +// - feeProtocol0: protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: protocol fee for token1 (must be 0 or between 4 and 10 inclusive). // // Returns: -// - newFee (uint8): the combined fee protocol value +// - newFee (uint8): the combined fee protocol value. // // Example: // If feeProtocol0 = 4 and feeProtocol1 = 5: @@ -257,32 +299,47 @@ func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { // // Binary: 0101 0100 // // ^^^^ ^^^^ // // fee1=5 fee0=4 +// +// Notes: +// - This function ensures that all pools under management are updated to use the same fee protocol. +// - Caller restrictions (e.g., admin or governance) are not enforced in this function. +// - Ensure the system is not halted before updating fees. func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { - common.IsHalted() - + assertOnlyNotHalted() if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( err, ufmt.Sprintf("expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } - // combine both protocol fee into a single byte: // - feePrtocol0 occupies the lower 4 bits // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) - - // iterate all pool + // Update slot0 for each pool for _, pool := range pools { - pool.slot0.feeProtocol = newFee + if pool != nil { + pool.slot0.feeProtocol = newFee + } } - // update slot0 slot0FeeProtocol = newFee - return newFee } +// validateFeeProtocol validates the fee protocol values for token0 and token1. +// +// This function checks whether the provided fee protocol values (`feeProtocol0` and `feeProtocol1`) +// are valid using the `isValidFeeProtocolValue` function. If either value is invalid, it returns +// an error indicating that the protocol fee percentage is invalid. +// +// Parameters: +// - feeProtocol0: uint8, the fee protocol value for token0. +// - feeProtocol1: uint8, the fee protocol value for token1. +// +// Returns: +// - error: Returns `errInvalidProtocolFeePct` if either `feeProtocol0` or `feeProtocol1` is invalid. +// Returns `nil` if both values are valid. func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { return errInvalidProtocolFeePct @@ -303,17 +360,10 @@ func CollectProtocolByAdmin( token1Path string, fee uint32, recipient std.Address, - amount0Requested string, // uint128 - amount1Requested string, // uint128 -) (string, string) { // uint128 x2 - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) - - caller := std.PrevRealm().Addr() - if err := common.AdminOnly(caller); err != nil { - panic(err) - } - + amount0Requested string, + amount1Requested string, +) (string, string) { + assertOnlyAdmin() amount0, amount1 := collectProtocol( token0Path, token1Path, @@ -323,11 +373,11 @@ func CollectProtocolByAdmin( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectProtocolByAdmin", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), @@ -351,14 +401,7 @@ func CollectProtocol( amount0Requested string, // uint128 amount1Requested string, // uint128 ) (string, string) { // uint128 x2 - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) - - caller := std.PrevRealm().Addr() - if err := common.GovernanceOnly(caller); err != nil { - panic(err) - } - + assertOnlyGovernance() amount0, amount1 := collectProtocol( token0Path, token1Path, @@ -368,11 +411,11 @@ func CollectProtocol( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectProtocol", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), @@ -384,6 +427,23 @@ func CollectProtocol( return amount0, amount1 } +// collectProtocol collects protocol fees for token0 and token1 from the specified pool. +// +// This function allows the collection of accumulated protocol fees for token0 and token1. It ensures +// the requested amounts do not exceed the available protocol fees in the pool and transfers the +// collected amounts to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token0. +// - token1Path: string, the path or identifier for token1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected protocol fees. +// - amount0Requested: string, the requested amount of token0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token0 collected, as a string. +// - string: The actual amount of token1 collected, as a string. func collectProtocol( token0Path string, token1Path string, @@ -392,25 +452,37 @@ func collectProtocol( amount0Requested string, amount1Requested string, ) (string, string) { - common.IsHalted() + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) + assertOnlyNotHalted() pool := GetPool(token0Path, token1Path, fee) amount0Req := u256.MustFromDecimal(amount0Requested) amount1Req := u256.MustFromDecimal(amount1Requested) - amount0 := u256Min(amount0Req, pool.protocolFees.token0) - amount1 := u256Min(amount1Req, pool.protocolFees.token1) + amount0 := u256Min(amount0Req, pool.ProtocolFeesToken0()) + amount1 := u256Min(amount1Req, pool.ProtocolFeesToken1()) - amount0, amount1 = pool.saveProtocolFees(amount0, amount1) + amount0, amount1 = pool.saveProtocolFees(amount0.Clone(), amount1.Clone()) uAmount0 := amount0.Uint64() uAmount1 := amount1.Uint64() token0Teller := common.GetTokenTeller(pool.token0Path) checkTransferError(token0Teller.Transfer(recipient, uAmount0)) + newBalanceToken0, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount0, true) + if err != nil { + panic(err) + } + pool.balances.token0 = newBalanceToken0 token1Teller := common.GetTokenTeller(pool.token1Path) checkTransferError(token1Teller.Transfer(recipient, uAmount1)) + newBalanceToken1, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount1, false) + if err != nil { + panic(err) + } + pool.balances.token1 = newBalanceToken1 return amount0.ToString(), amount1.ToString() } @@ -424,19 +496,19 @@ func collectProtocol( // Returns the adjusted amounts that will actually be collected for both tokens. func (p *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { cond01 := amount0.Gt(u256.Zero()) - cond02 := amount0.Eq(p.protocolFees.token0) + cond02 := amount0.Eq(p.ProtocolFeesToken0()) if cond01 && cond02 { amount0 = new(u256.Uint).Sub(amount0, u256.One()) } cond11 := amount1.Gt(u256.Zero()) - cond12 := amount1.Eq(p.protocolFees.token1) + cond12 := amount1.Eq(p.ProtocolFeesToken1()) if cond11 && cond12 { amount1 = new(u256.Uint).Sub(amount1, u256.One()) } - p.protocolFees.token0 = new(u256.Uint).Sub(p.protocolFees.token0, amount0) - p.protocolFees.token1 = new(u256.Uint).Sub(p.protocolFees.token1, amount1) + p.protocolFees.token0 = new(u256.Uint).Sub(p.ProtocolFeesToken0(), amount0) + p.protocolFees.token1 = new(u256.Uint).Sub(p.ProtocolFeesToken1(), amount1) // return rest fee return amount0, amount1 diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index c8f5c25b..575b9559 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -6,7 +6,6 @@ import ( "gno.land/p/demo/ufmt" - "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" en "gno.land/r/gnoswap/v1/emission" @@ -130,13 +129,10 @@ func CreatePool( token0Path string, token1Path string, fee uint32, - _sqrtPriceX96 string, + sqrtPriceX96 string, ) { - common.IsHalted() - en.MintAndDistributeGns() - - poolInfo := newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) - + assertOnlyNotHalted() + poolInfo := newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) if poolInfo.isSameTokenPath() { panic(addDetailToError( errDuplicateTokenInPool, @@ -146,14 +142,14 @@ func CreatePool( ), )) } + en.MintAndDistributeGns() // wrap first token0Path, token1Path = poolInfo.wrap() - poolPath := GetPoolPath(token0Path, token1Path, fee) // reinitialize poolInfo with wrapped tokens - poolInfo = newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) + poolInfo = newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) // then check if token0Path == token1Path if poolInfo.isSameTokenPath() { @@ -173,8 +169,7 @@ func CreatePool( )) } - // TODO: make this as a parameter - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() // check whether the pool already exist pool, exist := pools.Get(poolPath) @@ -208,7 +203,7 @@ func CreatePool( "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), - "sqrtPriceX96", _sqrtPriceX96, + "sqrtPriceX96", sqrtPriceX96, "internal_poolPath", poolPath, ) } @@ -220,46 +215,91 @@ func DoesPoolPathExist(poolPath string) bool { return exist } -// GetPool retrieves the pool for the given token paths and fee. -// It constructs the poolPath from the given parameters and returns the corresponding pool. -// Returns pool struct +// GetPool retrieves a pool instance based on the provided token paths and fee tier. +// +// This function determines the pool path by combining the paths of token0 and token1 along with the fee tier, +// and then retrieves the corresponding pool instance using that path. +// +// Parameters: +// - token0Path (string): The unique path for token0. +// - token1Path (string): The unique path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the provided tokens and fee tier. +// +// Notes: +// - The order of token paths (token0Path and token1Path) matters and should match the pool's configuration. +// - Ensure that the tokens and fee tier provided are valid and registered in the system. +// +// Example: +// pool := GetPool("gno.land/r/demo/wugnot", "gno.land/r/gnoswap/v1/gns", 3000) func GetPool(token0Path, token1Path string, fee uint32) *Pool { poolPath := GetPoolPath(token0Path, token1Path, fee) - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPool() || expected poolPath(%s) to exist", poolPath), - )) - } - - return pool + return GetPoolFromPoolPath(poolPath) } -// GetPoolFromPoolPath retrieves the pool for the given poolPath. +// GetPoolFromPoolPath retrieves a pool instance based on the provided pool path. +// +// This function checks if a pool exists for the given poolPath in the `pools` mapping. +// If the pool exists, it returns the pool instance. Otherwise, it panics with a descriptive error. +// +// Parameters: +// - poolPath (string): The unique identifier or path for the pool. +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the given poolPath. +// +// Panics: +// - If the `poolPath` does not exist in the `pools` mapping, it panics with an error message +// indicating that the expected poolPath was not found. +// +// Notes: +// - Ensure that the `poolPath` provided is valid and corresponds to an existing pool in the `pools` mapping. +// +// Example: +// pool := GetPoolFromPoolPath("path/to/pool") func GetPoolFromPoolPath(poolPath string) *Pool { pool, exist := pools[poolPath] if !exist { panic(addDetailToError( errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPoolFromPoolPath() || expected poolPath(%s) to exist", poolPath), + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), )) } - return pool } -// GetPoolPath generates a poolPath from the given token paths and fee. -// The poolPath is constructed by joining the token paths and fee with colons. +// GetPoolPath generates a unique pool path string based on the token paths and fee tier. +// +// This function ensures that the token paths are registered and sorted in alphabetical order +// before combining them with the fee tier to create a unique identifier for the pool. +// +// Parameters: +// - token0Path (string): The unique identifier or path for token0. +// - token1Path (string): The unique identifier or path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - string: A unique pool path string in the format "token0Path:token1Path:fee". +// +// Notes: +// - The function validates that both `token0Path` and `token1Path` are registered in the system +// using `common.MustRegistered`. +// - The token paths are sorted alphabetically to ensure consistent pool path generation, regardless +// of the input order. +// - This sorting guarantees that the pool path remains deterministic for the same pair of tokens and fee. +// +// Example: +// poolPath := GetPoolPath("path/to/token0", "path/to/token1", 3000) +// // Output: "path/to/token0:path/to/token1:3000" func GetPoolPath(token0Path, token1Path string, fee uint32) string { - common.MustRegistered(token0Path) - common.MustRegistered(token1Path) + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) - // TODO: this check is not unnecessary, if we are sure that // all the token paths in the pool are sorted in alphabetical order. if strings.Compare(token1Path, token0Path) < 0 { token0Path, token1Path = token1Path, token0Path } - return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) } diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 328315b5..36637a28 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -202,3 +202,59 @@ func TestCreatePool(t *testing.T) { resetObject(t) } + +func TestGetPool(t *testing.T) { + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "Panic - unregisterd poolPath ", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + action: func(t *testing.T) { + GetPool(barPath, fooPath, fee500) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-008] requested data not found || expected poolPath(gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500) to exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + } + }) + } +} diff --git a/pool/pool_test.gno b/pool/pool_test.gno index b58d95ac..7a705032 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -1,13 +1,10 @@ package pool import ( - "std" "testing" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" - - i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" "gno.land/r/gnoswap/v1/consts" @@ -100,7 +97,7 @@ func TestBurn(t *testing.T) { } // setup position for this test - posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) + posKey := getPositionKey(mockCaller, tt.tickLower, tt.tickUpper) mockPool.positions[posKey] = mockPosition if tt.expectPanic { @@ -140,3 +137,28 @@ func TestBurn(t *testing.T) { } } +func TestSetFeeProtocolInternal(t *testing.T) { + tests := []struct { + name string + feeProtocol0 uint8 + feeProtocol1 uint8 + eventName string + }{ + { + name: "set fee protocol by admin", + feeProtocol0: 4, + feeProtocol1: 5, + eventName: "SetFeeProtocolByAdmin", + }, + } + + for _, tt := range tests { + t.Run("set fee protocol by admin", func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + SetFeeProtocolByAdmin(tt.feeProtocol0, tt.feeProtocol1) + uassert.Equal(t, tt.feeProtocol0, pool.Slot0FeeProtocol()%16) + uassert.Equal(t, tt.feeProtocol1, pool.Slot0FeeProtocol()>>4) + }) + } +} diff --git a/pool/pool_transfer.gno b/pool/pool_transfer.gno index 201ae2cf..da3951f0 100644 --- a/pool/pool_transfer.gno +++ b/pool/pool_transfer.gno @@ -11,7 +11,7 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -// transferAndVerify performs a token transfer out of the pool while ensuring +// safeTransfer performs a token transfer out of the pool while ensuring // the pool has sufficient balance and updating internal accounting. // This function is typically used during swaps and liquidity removals. // @@ -33,7 +33,7 @@ import ( // 4. Update pool's internal balance // // Panics if any validation fails or if the transfer fails -func (p *Pool) transferAndVerify( +func (p *Pool) safeTransfer( to std.Address, tokenPath string, amount *i256.Int, @@ -47,16 +47,13 @@ func (p *Pool) transferAndVerify( absAmount := amount.Abs() - token0 := p.balances.token0 - token1 := p.balances.token1 + token0 := p.BalanceToken0() + token1 := p.BalanceToken1() if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { panic(err) } - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(absAmount) token := common.GetTokenTeller(tokenPath) checkTransferError(token.Transfer(to, amountUint64)) @@ -73,45 +70,79 @@ func (p *Pool) transferAndVerify( } } -// transferFromAndVerify performs a token transfer into the pool using transferFrom -// while updating the pool's internal accounting. This function is typically used -// during swaps and liquidity additions. +// safeTransferFrom securely transfers tokens into the pool while ensuring balance consistency. // -// The function assumes the sender has approved the pool to spend their tokens. +// This function performs the following steps: +// 1. Validates and converts the transfer amount to `uint64` using `safeConvertToUint64`. +// 2. Executes the token transfer using `TransferFrom` via the token teller contract. +// 3. Verifies that the destination balance reflects the correct amount after transfer. +// 4. Updates the pool's internal balances (`token0` or `token1`) and validates the updated state. // // Parameters: -// - from: source address for the transfer -// - to: destination address (typically the pool) -// - tokenPath: path identifier of the token to transfer -// - amount: amount to transfer (must be positive) -// - isToken0: true if transferring token0, false for token1 +// - from (std.Address): Source address for the token transfer. +// - to (std.Address): Destination address, typically the pool address. +// - tokenPath (string): Path identifier for the token being transferred. +// - amount (*u256.Uint): The amount of tokens to transfer (must be a positive value). +// - isToken0 (bool): A flag indicating whether the token being transferred is token0 (`true`) or token1 (`false`). // -// The function will: -// 1. Convert amount to uint64 (must fit) -// 2. Execute the transferFrom -// 3. Update pool's internal balance +// Panics: +// - If the `amount` exceeds the uint64 range during conversion. +// - If the token transfer (`TransferFrom`) fails. +// - If the destination balance after the transfer does not match the expected amount. +// - If the pool's internal balances (`token0` or `token1`) overflow or become inconsistent. // -// Panics if the amount conversion fails or if the transfer fails -func (p *Pool) transferFromAndVerify( +// Notes: +// - The function assumes that the sender (`from`) has approved the pool to spend the specified tokens. +// - The balance consistency check ensures that no tokens are lost or double-counted during the transfer. +// - Pool balance updates are performed atomically to ensure internal consistency. +// +// Example: +// p.safeTransferFrom( +// +// sender, poolAddress, "path/to/token0", u256.MustFromDecimal("1000"), true +// +// ) +func (p *Pool) safeTransferFrom( from, to std.Address, tokenPath string, amount *u256.Uint, isToken0 bool, ) { - absAmount := amount - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(amount) - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.TransferFrom(from, to, amountUint64)) + token := common.GetToken(tokenPath) + beforeBalance := token.BalanceOf(to) + + teller := common.GetTokenTeller(tokenPath) + checkTransferError(teller.TransferFrom(from, to, amountUint64)) + + afterBalance := token.BalanceOf(to) + if (beforeBalance + amountUint64) != afterBalance { + panic(ufmt.Sprintf( + "%v. beforeBalance(%d) + amount(%d) != afterBalance(%d)", + errTransferFailed, beforeBalance, amountUint64, afterBalance, + )) + } // update pool balances if isToken0 { - p.balances.token0 = new(u256.Uint).Add(p.balances.token0, absAmount) + beforeToken0 := p.balances.token0.Clone() + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, amount) + if p.balances.token0.Lt(beforeToken0) { + panic(ufmt.Sprintf( + "%v. token0(%s) < beforeToken0(%s)", + errBalanceUpdateFailed, p.balances.token0.ToString(), beforeToken0.ToString(), + )) + } } else { - p.balances.token1 = new(u256.Uint).Add(p.balances.token1, absAmount) + beforeToken1 := p.balances.token1.Clone() + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, amount) + if p.balances.token1.Lt(beforeToken1) { + panic(ufmt.Sprintf( + "%v. token1(%s) < beforeToken1(%s)", + errBalanceUpdateFailed, p.balances.token1.ToString(), beforeToken1.ToString(), + )) + } } } @@ -150,7 +181,7 @@ func updatePoolBalance( if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%v. cannot decrease, token0(%s) - amount(%s)", - errTransferFailed, token0.ToString(), amount.ToString(), + errBalanceUpdateFailed, token0.ToString(), amount.ToString(), ) } return newBalance, nil @@ -160,12 +191,13 @@ func updatePoolBalance( if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%v. cannot decrease, token1(%s) - amount(%s)", - errTransferFailed, token1.ToString(), amount.ToString(), + errBalanceUpdateFailed, token1.ToString(), amount.ToString(), ) } return newBalance, nil } +// isBalanceOverflowOrNegative checks if the balance calculation resulted in an overflow or negative value. func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { return overflow || newBalance.Lt(u256.Zero()) } diff --git a/pool/pool_transfer_test.gno b/pool/pool_transfer_test.gno index 562825b9..9a610d51 100644 --- a/pool/pool_transfer_test.gno +++ b/pool/pool_transfer_test.gno @@ -134,10 +134,10 @@ func TestTransferFromAndVerify(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - TokenFaucet(t, fooPath, pusers.AddressOrName(tt.from)) - TokenApprove(t, fooPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + TokenFaucet(t, tt.tokenPath, pusers.AddressOrName(tt.from)) + TokenApprove(t, tt.tokenPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) - tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + tt.pool.safeTransferFrom(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) if !tt.pool.balances.token0.Eq(tt.expectedBal0) { t.Errorf("token0 balance mismatch: expected %s, got %s", @@ -165,7 +165,7 @@ func TestTransferFromAndVerify(t *testing.T) { TokenFaucet(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr"))) TokenApprove(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr")), pusers.AddressOrName(consts.POOL_ADDR), u256.MustFromDecimal(negativeAmount.Abs().ToString()).Uint64()) - pool.transferFromAndVerify( + pool.safeTransferFrom( testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), fooPath, @@ -197,7 +197,7 @@ func TestTransferFromAndVerify(t *testing.T) { } }() - pool.transferFromAndVerify( + pool.safeTransferFrom( testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), fooPath, diff --git a/pool/position.gno b/pool/position.gno index 64d37e9d..8afba990 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -15,49 +15,52 @@ var ( Q128 = u256.MustFromDecimal(consts.Q128) ) -// positionGetKey generates a unique key for a position based on the owner's address and the tick range. -func positionGetKey( +// getPositionKey generates a unique, encoded key for a liquidity position. +// +// This function creates a unique key for identifying a liquidity position in a pool. The key is based +// on the position's owner address, lower tick, and upper tick values. The generated key is then encoded +// as a base64 string to ensure compatibility and uniqueness. +// +// Parameters: +// - owner: std.Address, the address of the position's owner. +// - tickLower: int32, the lower tick boundary for the position. +// - tickUpper: int32, the upper tick boundary for the position. +// +// Returns: +// - string: A base64-encoded string representing the unique position key. +// +// Workflow: +// 1. Validates that the `owner` address is valid using `assertOnlyValidAddress`. +// 2. Ensures `tickLower` is less than `tickUpper` using `assertTickLowerLessThanUpper`. +// 3. Constructs the position key as a formatted string: +// "____" +// 4. Encodes the generated position key into a base64 string for safety and uniqueness. +// 5. Returns the encoded position key. +// +// Example: +// +// owner := std.Address("0x123456789") +// positionKey := getPositionKey(owner, 100, 200) +// fmt.Println("Position Key:", positionKey) +// // Output: base64-encoded string representing "0x123456789__100__200" +// +// Notes: +// - The base64 encoding ensures that the position key can be safely used as an identifier +// across different systems or data stores. +// - The function will panic if: +// - The `owner` address is invalid. +// - `tickLower` is greater than or equal to `tickUpper`. +func getPositionKey( owner std.Address, tickLower int32, tickUpper int32, ) string { - if !owner.IsValid() { - panic(addDetailToError( - errInvalidAddress, - ufmt.Sprintf("position.gno__positionGetKey() || invalid owner address %s", owner.String()), - )) - } - - if tickLower > tickUpper { - panic(addDetailToError( - errInvalidTickRange, - ufmt.Sprintf("position.gno__positionGetKey() || tickLower(%d) is greater than tickUpper(%d)", tickLower, tickUpper), - )) - } + assertOnlyValidAddress(owner) + assertTickLowerLessThanUpper(tickLower, tickUpper) positionKey := ufmt.Sprintf("%s__%d__%d", owner.String(), tickLower, tickUpper) - - encoded := base64.StdEncoding.EncodeToString([]byte(positionKey)) - return encoded -} - -// positionUpdateWithKey updates a position in the pool and returns the updated position. -func (pool *Pool) positionUpdateWithKey( - positionKey string, - liquidityDelta *i256.Int, - feeGrowthInside0X128 *u256.Uint, - feeGrowthInside1X128 *u256.Uint, -) PositionInfo { - // if pointer is nil, set to zero for calculation - liquidityDelta = liquidityDelta.NilToZero() - feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() - feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() - - positionToUpdate := pool.positions[positionKey] - positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) - pool.positions[positionKey] = positionAfterUpdate - - return positionAfterUpdate + encodedPositionKey := base64.StdEncoding.EncodeToString([]byte(positionKey)) + return encodedPositionKey } // positionUpdate calculates and returns an updated PositionInfo. @@ -67,7 +70,7 @@ func positionUpdate( feeGrowthInside0X128 *u256.Uint, feeGrowthInside1X128 *u256.Uint, ) PositionInfo { - position.init() + position.valueOrZero() var liquidityNext *u256.Uint if liquidityDelta.IsZero() { @@ -105,25 +108,48 @@ func positionUpdate( return position } -// receiver getters +// positionUpdateWithKey updates a position in the pool and returns the updated position. +func (p *Pool) positionUpdateWithKey( + positionKey string, + liquidityDelta *i256.Int, + feeGrowthInside0X128 *u256.Uint, + feeGrowthInside1X128 *u256.Uint, +) PositionInfo { + // if pointer is nil, set to zero for calculation + liquidityDelta = liquidityDelta.NilToZero() + feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() + feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() + + positionToUpdate := p.GetPosition(positionKey) + positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) + + p.positions[positionKey] = positionAfterUpdate + + return positionAfterUpdate +} -func (p *Pool) GetPositionLiquidity(key string) *u256.Uint { +// PositionLiquidity returns the liquidity of a position. +func (p *Pool) PositionLiquidity(key string) *u256.Uint { return p.mustGetPosition(key).liquidity } -func (p *Pool) GetPositionFeeGrowthInside0LastX128(key string) *u256.Uint { +// PositionFeeGrowthInside0LastX128 returns the fee growth of token0 inside a position. +func (p *Pool) PositionFeeGrowthInside0LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside0LastX128 } -func (p *Pool) GetPositionFeeGrowthInside1LastX128(key string) *u256.Uint { +// PositionFeeGrowthInside1LastX128 returns the fee growth of token1 inside a position. +func (p *Pool) PositionFeeGrowthInside1LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside1LastX128 } -func (p *Pool) GetPositionTokensOwed0(key string) *u256.Uint { +// PositionTokensOwed0 returns the amount of token0 owed by a position. +func (p *Pool) PositionTokensOwed0(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed0 } -func (p *Pool) GetPositionTokensOwed1(key string) *u256.Uint { +// PositionTokensOwed1 returns the amount of token1 owed by a position. +func (p *Pool) PositionTokensOwed1(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed1 } @@ -135,6 +161,15 @@ func (p *Pool) mustGetPosition(key string) PositionInfo { ufmt.Sprintf("position(%s) does not exist", key), )) } + return position +} +func (p *Pool) GetPosition(key string) PositionInfo { + position, exist := p.positions[key] + if !exist { + newPosition := PositionInfo{} + newPosition.valueOrZero() + return newPosition + } return position } diff --git a/pool/position_modify.gno b/pool/position_modify.gno index c6c57400..fe3bf876 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -1,11 +1,12 @@ package pool import ( - "gno.land/r/gnoswap/v1/common" - + "gno.land/p/demo/ufmt" i256 "gno.land/p/gnoswap/int256" plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) // modifyPosition updates a position in the pool and calculates the amount of tokens @@ -25,46 +26,49 @@ import ( // - *u256.Uint: amount of token0 needed/returned // - *u256.Uint: amount of token1 needed/returned func (p *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { + checkTicks(params.tickLower, params.tickUpper) + + // get current state and price bounds + tick := p.Slot0Tick() // update position state - position := p.updatePosition(params) + position := p.updatePosition(params, tick) liqDelta := params.liquidityDelta - if liqDelta.IsZero() { return position, u256.Zero(), u256.Zero() } amount0, amount1 := i256.Zero(), i256.Zero() - // get current state and price bounds - tick := p.slot0.tick // covert ticks to sqrt price to use in amount calculations // price = 1.0001^tick, but we use sqrtPriceX96 sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) - sqrtPriceX96 := p.slot0.sqrtPriceX96 + sqrtPriceX96 := p.Slot0SqrtPriceX96() // calculate token amounts based on current price position relative to range switch { case tick < params.tickLower: // case 1 // full range between lower and upper tick is used for token0 + // current tick is below the passed range; liquidity can only become in range by crossing from left to + // right, when we'll need _more_ token0 (it's becoming more valuable) so user must provide it amount0 = calculateToken0Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) case tick < params.tickUpper: // case 2 liquidityBefore := p.liquidity - // token0 used from current price to upper tick amount0 = calculateToken0Amount(sqrtPriceX96, sqrtRatioUpper, liqDelta) // token1 used from lower tick to current price amount1 = calculateToken1Amount(sqrtRatioLower, sqrtPriceX96, liqDelta) - // update pool's active liquidity since price is in range p.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) default: // case 3 // full range between lower and upper tick is used for token1 + // current tick is above the passed range; liquidity can only become in range by crossing from right to + // left, when we'll need _more_ token1 (it's becoming more valuable) so user must provide it amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } @@ -80,3 +84,36 @@ func calculateToken1Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityD res := plp.SqrtPriceMathGetAmount1DeltaStr(sqrtPriceLower, sqrtPriceUpper, liquidityDelta) return i256.MustFromDecimal(res) } + +func checkTicks(tickLower, tickUpper int32) { + assertTickLowerLessThanUpper(tickLower, tickUpper) + assertValidTickLower(tickLower) + assertValidTickUpper(tickUpper) +} + +func assertTickLowerLessThanUpper(tickLower, tickUpper int32) { + if tickLower >= tickUpper { + panic(addDetailToError( + errInvalidTickRange, + ufmt.Sprintf("tickLower(%d), tickUpper(%d)", tickLower, tickUpper), + )) + } +} + +func assertValidTickLower(tickLower int32) { + if tickLower < consts.MIN_TICK { + panic(addDetailToError( + errTickLowerInvalid, + ufmt.Sprintf("tickLower(%d) < MIN_TICK(%d)", tickLower, consts.MIN_TICK), + )) + } +} + +func assertValidTickUpper(tickUpper int32) { + if tickUpper > consts.MAX_TICK { + panic(addDetailToError( + errTickUpperInvalid, + ufmt.Sprintf("tickUpper(%d) > MAX_TICK(%d)", tickUpper, consts.MAX_TICK), + )) + } +} diff --git a/pool/position_modify_test.gno b/pool/position_modify_test.gno index a6c2da6d..4a8ed140 100644 --- a/pool/position_modify_test.gno +++ b/pool/position_modify_test.gno @@ -134,3 +134,159 @@ func TestModifyPositionEdgeCases(t *testing.T) { uassert.Equal(t, amount1.ToString(), "2958014") }) } + +func TestAssertTickLowerLessThanUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower is less than tickUpper", + tickLower: -100, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower equals tickUpper", + tickLower: 50, + tickUpper: 50, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(50), tickUpper(50)", + }, + { + name: "tickLower greater than tickUpper", + tickLower: 200, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertTickLowerLessThanUpper(tt.tickLower, tt.tickUpper) + }) + } +} + +func TestAssertValidTickLower(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower equals MIN_TICK", + tickLower: consts.MIN_TICK, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower greater than MIN_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: 50, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower less than MIN_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-028] tickLower is invalid || tickLower(-887273) < MIN_TICK(-887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickLower(tt.tickLower) + }) + } +} + +func TestAssertValidTickUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickUpper equals MAX_TICK", + tickLower: consts.MIN_TICK, + tickUpper: consts.MAX_TICK, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper less than MAX_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: consts.MAX_TICK - 1, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper greater than MAX_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: consts.MAX_TICK + 1, + shouldPanic: true, + expected: "[GNOSWAP-POOL-029] tickUpper is invalid || tickUpper(887273) > MAX_TICK(887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickUpper(tt.tickUpper) + }) + } +} diff --git a/pool/position_test.gno b/pool/position_test.gno index 46928b6a..b96821af 100644 --- a/pool/position_test.gno +++ b/pool/position_test.gno @@ -25,18 +25,23 @@ func TestPositionGetKey(t *testing.T) { panicMsg string expectedKey string }{ - {invalidAddr, 100, 200, true, `[GNOSWAP-POOL-023] invalid address || position.gno__positionGetKey() || invalid owner address invalidAddr`, ""}, // invalid address - {validAddr, 200, 100, true, `[GNOSWAP-POOL-024] tickLower is greater than tickUpper || position.gno__positionGetKey() || tickLower(200) is greater than tickUpper(100)`, ""}, // tickLower > tickUpper - {validAddr, -100, -200, true, `[GNOSWAP-POOL-024] tickLower is greater than tickUpper || position.gno__positionGetKey() || tickLower(-100) is greater than tickUpper(-200)`, ""}, // tickLower > tickUpper - {validAddr, 100, 100, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18xMDA="}, // tickLower == tickUpper - {validAddr, 100, 200, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18yMDA="}, // tickLower < tickUpper + {invalidAddr, 100, 200, true, `[GNOSWAP-POOL-023] invalid address || (invalidAddr)`, ""}, // invalid address + {validAddr, 200, 100, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)`, ""}, // tickLower > tickUpper + {validAddr, -100, -200, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(-100), tickUpper(-200)`, ""}, // tickLower > tickUpper + {validAddr, 100, 100, true, "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(100), tickUpper(100)", ""}, // tickLower == tickUpper + {validAddr, 100, 200, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18yMDA="}, // tickLower < tickUpper } for _, tc := range tests { + defer func() { + if r := recover(); r != nil { + uassert.Equal(t, tc.panicMsg, r.(string)) + } + }() if tc.shouldPanic { - uassert.PanicsWithMessage(t, tc.panicMsg, func() { positionGetKey(tc.owner, tc.tickLower, tc.tickUpper) }) + uassert.PanicsWithMessage(t, tc.panicMsg, func() { getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) }) } else { - key := positionGetKey(tc.owner, tc.tickLower, tc.tickUpper) + key := getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) uassert.Equal(t, tc.expectedKey, key) } } @@ -55,7 +60,7 @@ func TestPositionUpdateWithKey(t *testing.T) { ) dummyPool = newPool(poolParams) - positionKey = positionGetKey( + positionKey = getPositionKey( testutils.TestAddress("dummyAddr"), 100, 200, diff --git a/pool/position_update.gno b/pool/position_update.gno index 5f21e49f..827d91f1 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -1,68 +1,97 @@ package pool -import ( - u256 "gno.land/p/gnoswap/uint256" -) - -func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionInfo { - feeGrowthGlobal0X128 := pool.feeGrowthGlobal0X128.Clone() - feeGrowthGlobal1X128 := pool.feeGrowthGlobal1X128.Clone() +// updatePosition modifies the position's liquidity and updates the corresponding tick states. +// +// This function updates the position data based on the specified liquidity delta and tick range. +// It also manages the fee growth, tick state flipping, and cleanup of unused tick data. +// +// Parameters: +// - positionParams: ModifyPositionParams, the parameters for the position modification, which include: +// - owner: The address of the position owner. +// - tickLower: The lower tick boundary of the position. +// - tickUpper: The upper tick boundary of the position. +// - liquidityDelta: The change in liquidity (positive or negative). +// - tick: int32, the current tick position. +// +// Returns: +// - PositionInfo: The updated position information. +// +// Workflow: +// 1. Clone the global fee growth values (token 0 and token 1). +// 2. If the liquidity delta is non-zero: +// - Update the lower and upper ticks using `tickUpdate`, flipping their states if necessary. +// - If a tick's state was flipped, update the tick bitmap to reflect the new state. +// 3. Calculate the fee growth inside the tick range using `getFeeGrowthInside`. +// 4. Generate a unique position key and update the position data using `positionUpdateWithKey`. +// 5. If liquidity is being removed (negative delta), clean up unused tick data by deleting the tick entries. +// 6. Return the updated position. +// +// Notes: +// - The function flips the tick states and cleans up unused tick data when liquidity is removed. +// - It ensures fee growth and position data remain accurate after the update. +// +// Example Usage: +// +// updatedPosition := pool.updatePosition(positionParams, currentTick) +// fmt.Println("Updated Position Info:", updatedPosition) +func (p *Pool) updatePosition(positionParams ModifyPositionParams, tick int32) PositionInfo { + feeGrowthGlobal0X128 := p.FeeGrowthGlobal0X128().Clone() + feeGrowthGlobal1X128 := p.FeeGrowthGlobal1X128().Clone() var flippedLower, flippedUpper bool if !(positionParams.liquidityDelta.IsZero()) { - flippedLower = pool.tickUpdate( + flippedLower = p.tickUpdate( positionParams.tickLower, - pool.slot0.tick, + tick, positionParams.liquidityDelta, feeGrowthGlobal0X128, feeGrowthGlobal1X128, false, - pool.maxLiquidityPerTick, + p.maxLiquidityPerTick, ) - flippedUpper = pool.tickUpdate( + flippedUpper = p.tickUpdate( positionParams.tickUpper, - pool.slot0.tick, + tick, positionParams.liquidityDelta, feeGrowthGlobal0X128, feeGrowthGlobal1X128, true, - pool.maxLiquidityPerTick, + p.maxLiquidityPerTick, ) if flippedLower { - pool.tickBitmapFlipTick(positionParams.tickLower, pool.tickSpacing) + p.tickBitmapFlipTick(positionParams.tickLower, p.tickSpacing) } if flippedUpper { - pool.tickBitmapFlipTick(positionParams.tickUpper, pool.tickSpacing) + p.tickBitmapFlipTick(positionParams.tickUpper, p.tickSpacing) } } - feeGrowthInside0X128, feeGrowthInside1X128 := pool.calculateFeeGrowthInside( + feeGrowthInside0X128, feeGrowthInside1X128 := p.getFeeGrowthInside( positionParams.tickLower, positionParams.tickUpper, - pool.slot0.tick, + tick, feeGrowthGlobal0X128, feeGrowthGlobal1X128, ) - positionKey := positionGetKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) - - position := pool.positionUpdateWithKey( + positionKey := getPositionKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) + position := p.positionUpdateWithKey( positionKey, positionParams.liquidityDelta, - u256.MustFromDecimal(feeGrowthInside0X128.ToString()), - u256.MustFromDecimal(feeGrowthInside1X128.ToString()), + feeGrowthInside0X128.Clone(), + feeGrowthInside1X128.Clone(), ) + // clear any tick data that is no longer needed if positionParams.liquidityDelta.IsNeg() { if flippedLower { - delete(pool.ticks, positionParams.tickLower) + delete(p.ticks, positionParams.tickLower) } - if flippedUpper { - delete(pool.ticks, positionParams.tickUpper) + delete(p.ticks, positionParams.tickUpper) } } diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno index 24ca4cb5..ad2e6707 100644 --- a/pool/position_update_test.gno +++ b/pool/position_update_test.gno @@ -3,19 +3,15 @@ package pool import ( "testing" - "std" - - "gno.land/p/demo/uassert" - - "gno.land/r/gnoswap/v1/consts" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/consts" ) func TestUpdatePosition(t *testing.T) { poolParams := &createPoolParams{ - token0Path: "token0", - token1Path: "token1", + token0Path: "token0", + token1Path: "token1", fee: 500, tickSpacing: 10, sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 @@ -23,8 +19,8 @@ func TestUpdatePosition(t *testing.T) { p := newPool(poolParams) tests := []struct { - name string - positionParams ModifyPositionParams + name string + positionParams ModifyPositionParams expectLiquidity *u256.Uint }{ { @@ -61,11 +57,12 @@ func TestUpdatePosition(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - position := p.updatePosition(tt.positionParams) - + tick := p.Slot0Tick() + position := p.updatePosition(tt.positionParams, tick) + if !position.liquidity.Eq(tt.expectLiquidity) { - t.Errorf("liquidity mismatch: expected %s, got %s", - tt.expectLiquidity.ToString(), + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), position.liquidity.ToString()) } diff --git a/pool/protocol_fee_pool_creation.gno b/pool/protocol_fee_pool_creation.gno index a5628429..c6e4e06a 100644 --- a/pool/protocol_fee_pool_creation.gno +++ b/pool/protocol_fee_pool_creation.gno @@ -32,7 +32,7 @@ func SetPoolCreationFee(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFee", "prevAddr", prevAddr, @@ -54,7 +54,7 @@ func SetPoolCreationFeeByAdmin(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/protocol_fee_withdrawal.gno b/pool/protocol_fee_withdrawal.gno index cacc602c..d7d208b9 100644 --- a/pool/protocol_fee_withdrawal.gno +++ b/pool/protocol_fee_withdrawal.gno @@ -72,7 +72,7 @@ func HandleWithdrawalFee( token1Teller := common.GetTokenTeller(token1Path) checkTransferError(token1Teller.TransferFrom(positionCaller, consts.PROTOCOL_FEE_ADDR, feeAmount1.Uint64())) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "WithdrawalFee", "prevAddr", prevAddr, @@ -106,7 +106,7 @@ func SetWithdrawalFee(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFee", "prevAddr", prevAddr, @@ -126,7 +126,7 @@ func SetWithdrawalFeeByAdmin(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/swap.gno b/pool/swap.gno index 47c5e2d6..3b8b9c60 100644 --- a/pool/swap.gno +++ b/pool/swap.gno @@ -107,7 +107,7 @@ func Swap( // actual swap pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Swap", @@ -512,13 +512,13 @@ func tickTransition(step StepComputations, zeroForOne bool, state SwapState, poo func (p *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { if zeroForOne { // payer > POOL - p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) // POOL > recipient - p.transferAndVerify(recipient, p.token1Path, amount1, false) + p.safeTransfer(recipient, p.token1Path, amount1, false) } else { // payer > POOL - p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) // POOL > recipient - p.transferAndVerify(recipient, p.token0Path, amount0, true) + p.safeTransfer(recipient, p.token0Path, amount0, true) } } diff --git a/pool/tests/__TEST_pool_burn_test.gnoA b/pool/tests/__TEST_pool_burn_test.gnoA index bcaa501b..64d59735 100644 --- a/pool/tests/__TEST_pool_burn_test.gnoA +++ b/pool/tests/__TEST_pool_burn_test.gnoA @@ -81,7 +81,7 @@ func TestDoesNotClear(t *testing.T) { uassert.Equal(t, liq.ToString(), "0") // tokensOwed - thisPositionKey := positionGetKey(consts.POSITION_ADDR, -887160, 887160) + thisPositionKey := getPositionKey(consts.POSITION_ADDR, -887160, 887160) thisPosition := thisPool.positions[thisPositionKey] tokensOwed0 := thisPosition.tokensOwed0 diff --git a/pool/tests/__TEST_pool_spec_#6_test.gnoA b/pool/tests/__TEST_pool_spec_#6_test.gnoA index 03aed8d6..1e36d3c2 100644 --- a/pool/tests/__TEST_pool_spec_#6_test.gnoA +++ b/pool/tests/__TEST_pool_spec_#6_test.gnoA @@ -133,7 +133,7 @@ func TestWorkAccross(t *testing.T) { ) // tokensOwed - thisPositionKey := positionGetKey(consts.POSITION_ADDR, -887270, 887270) + thisPositionKey := getPositionKey(consts.POSITION_ADDR, -887270, 887270) thisPosition := thisPool.positions[thisPositionKey] tokensOwed0 := thisPosition.tokensOwed0 diff --git a/pool/tick.gno b/pool/tick.gno index d0db635c..63e07022 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -19,6 +19,42 @@ func calculateMaxLiquidityPerTick(tickSpacing int32) *u256.Uint { return new(u256.Uint).Div(u256.MustFromDecimal(consts.MAX_UINT128), u256.NewUint(numTicks)) } +// getFeeGrowthBelowX128 calculates the fee growth below a specified tick. +// +// This function computes the fee growth for token 0 and token 1 below a given tick (`tickLower`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickLower`. +// +// Parameters: +// - tickLower: int32, the lower tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - lowerTick: TickInfo, the fee growth and liquidity details for the lower tick. +// +// Returns: +// - *u256.Uint: Fee growth below `tickLower` for token 0. +// - *u256.Uint: Fee growth below `tickLower` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is greater than or equal to `tickLower`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `lowerTick`. +// 2. If `tickCurrent` is below `tickLower`: +// - Compute the fee growth below the lower tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent >= tickLower`, the fee growth outside the lower tick is returned as-is. +// - If `tickCurrent < tickLower`, the fee growth is calculated as: +// feeGrowthBelow = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( +// 100, 150, globalFeeGrowth0, globalFeeGrowth1, lowerTickInfo, +// ) +// fmt.Println("Fee Growth Below:", feeGrowth0, feeGrowth1) func getFeeGrowthBelowX128( tickLower, tickCurrent int32, feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, @@ -28,12 +64,48 @@ func getFeeGrowthBelowX128( return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 } - below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) - below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) + feeGrowthBelow0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) + feeGrowthBelow1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) - return below0X128, below1X128 + return feeGrowthBelow0X128, feeGrowthBelow1X128 } +// getFeeGrowthAboveX128 calculates the fee growth above a specified tick. +// +// This function computes the fee growth for token 0 and token 1 above a given tick (`tickUpper`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickUpper`. +// +// Parameters: +// - tickUpper: int32, the upper tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - upperTick: TickInfo, the fee growth and liquidity details for the upper tick. +// +// Returns: +// - *u256.Uint: Fee growth above `tickUpper` for token 0. +// - *u256.Uint: Fee growth above `tickUpper` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is less than `tickUpper`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `upperTick`. +// 2. If `tickCurrent` is greater than or equal to `tickUpper`: +// - Compute the fee growth above the upper tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent < tickUpper`, the fee growth outside the upper tick is returned as-is. +// - If `tickCurrent >= tickUpper`, the fee growth is calculated as: +// feeGrowthAbove = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( +// 200, 150, globalFeeGrowth0, globalFeeGrowth1, upperTickInfo, +// ) +// fmt.Println("Fee Growth Above:", feeGrowth0, feeGrowth1) func getFeeGrowthAboveX128( tickUpper, tickCurrent int32, feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, @@ -43,15 +115,51 @@ func getFeeGrowthAboveX128( return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 } - above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) - above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) + feeGrowthAbove0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) + feeGrowthAbove1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) - return above0X128, above1X128 + return feeGrowthAbove0X128, feeGrowthAbove1X128 } -// calculateFeeGrowthInside calculates the fee growth inside a tick range, -// and returns the fee growth inside for both tokens. -func (p *Pool) calculateFeeGrowthInside( +// getFeeGrowthInside calculates the fee growth within a specified tick range. +// +// This function computes the accumulated fee growth for token 0 and token 1 inside a given tick range +// (`tickLower` to `tickUpper`) relative to the current tick position (`tickCurrent`). It isolates the fee +// growth within the range by subtracting the fee growth below the lower tick and above the upper tick +// from the global fee growth. +// +// Parameters: +// - tickLower: int32, the lower tick boundary of the range. +// - tickUpper: int32, the upper tick boundary of the range. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// +// Returns: +// - *u256.Uint: Fee growth inside the tick range for token 0. +// - *u256.Uint: Fee growth inside the tick range for token 1. +// +// Workflow: +// 1. Retrieve the tick information (`lower` and `upper`) for the lower and upper tick boundaries +// using `p.getTick`. +// 2. Calculate the fee growth below the lower tick using `getFeeGrowthBelowX128`. +// 3. Calculate the fee growth above the upper tick using `getFeeGrowthAboveX128`. +// 4. Subtract the fee growth below and above the range from the global fee growth values: +// feeGrowthInside = feeGrowthGlobal - feeGrowthBelow - feeGrowthAbove +// 5. Return the computed fee growth values for token 0 and token 1 within the range. +// +// Behavior: +// - The fee growth is isolated within the range `[tickLower, tickUpper]`. +// - The function ensures the calculations accurately consider the tick boundaries and the current tick position. +// +// Example: +// +// feeGrowth0, feeGrowth1 := pool.getFeeGrowthInside( +// 100, 200, 150, globalFeeGrowth0, globalFeeGrowth1, +// ) +// fmt.Println("Fee Growth Inside (Token 0):", feeGrowth0) +// fmt.Println("Fee Growth Inside (Token 1):", feeGrowth1) +func (p *Pool) getFeeGrowthInside( tickLower int32, tickUpper int32, tickCurrent int32, @@ -70,11 +178,53 @@ func (p *Pool) calculateFeeGrowthInside( return feeGrowthInside0X128, feeGrowthInside1X128 } -// tickUpdate updates a tick's state and returns whether the tick was flipped. +// tickUpdate updates the state of a specific tick. +// +// This function applies a given liquidity change (liquidityDelta) to the specified tick, updates +// the fee growth values if necessary, and adjusts the net liquidity based on whether the tick +// is an upper or lower boundary. It also verifies that the total liquidity does not exceed the +// maximum allowed value and ensures the net liquidity stays within the valid int128 range. +// +// Parameters: +// - tick: int32, the index of the tick to update. +// - tickCurrent: int32, the current active tick index. +// - liquidityDelta: *i256.Int, the amount of liquidity to add or remove. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth value for token 0. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth value for token 1. +// - upper: bool, indicates if this is the upper boundary (true for upper, false for lower). +// - maxLiquidity: *u256.Uint, the maximum allowed liquidity. +// +// Returns: +// - flipped: bool, indicates if the tick's initialization state has changed. +// (e.g., liquidity transitioning from zero to non-zero, or vice versa) +// +// Workflow: +// 1. Nil input values are replaced with zero. +// 2. The function retrieves the tick information for the specified tick index. +// 3. Applies the liquidityDelta to compute the new total liquidity (liquidityGross). +// - If the total liquidity exceeds the maximum allowed value, the function panics. +// 4. Checks whether the tick's initialized state has changed and sets the `flipped` flag. +// 5. If the tick was previously uninitialized and its index is less than or equal to the current tick, +// the fee growth values are initialized to the current global values. +// 6. Updates the tick's net liquidity: +// - For an upper boundary, it subtracts liquidityDelta. +// - For a lower boundary, it adds liquidityDelta. +// - Ensures the net liquidity remains within the int128 range using `checkOverFlowInt128`. +// 7. Updates the tick's state with the new values. +// 8. Returns whether the tick's initialized state has flipped. +// +// Panic Conditions: +// - The total liquidity (liquidityGross) exceeds the maximum allowed liquidity (maxLiquidity). +// - The net liquidity (liquidityNet) exceeds the int128 range. +// +// Example: +// +// flipped := pool.tickUpdate(10, 5, liquidityDelta, feeGrowth0, feeGrowth1, true, maxLiquidity) +// fmt.Println("Tick flipped:", flipped) func (p *Pool) tickUpdate( tick int32, tickCurrent int32, - liquidityDelta *i256.Int, // int128 + liquidityDelta *i256.Int, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, upper bool, @@ -84,15 +234,15 @@ func (p *Pool) tickUpdate( feeGrowthGlobal0X128 = feeGrowthGlobal0X128.NilToZero() feeGrowthGlobal1X128 = feeGrowthGlobal1X128.NilToZero() - thisTick := p.getTick(tick) + tickInfo := p.getTick(tick) - liquidityGrossBefore := thisTick.liquidityGross + liquidityGrossBefore := tickInfo.liquidityGross.Clone() liquidityGrossAfter := liquidityMathAddDelta(liquidityGrossBefore, liquidityDelta) if !(liquidityGrossAfter.Lte(maxLiquidity)) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("tick.gno__tickUpdate() || liquidityGrossAfter(%s) overflows maxLiquidity(%s)", liquidityGrossAfter.ToString(), maxLiquidity.ToString()), + ufmt.Sprintf("liquidityGrossAfter(%s) overflows maxLiquidity(%s)", liquidityGrossAfter.ToString(), maxLiquidity.ToString()), )) } @@ -100,22 +250,23 @@ func (p *Pool) tickUpdate( if liquidityGrossBefore.IsZero() { if tick <= tickCurrent { - thisTick.feeGrowthOutside0X128 = feeGrowthGlobal0X128 - thisTick.feeGrowthOutside1X128 = feeGrowthGlobal1X128 + tickInfo.feeGrowthOutside0X128 = feeGrowthGlobal0X128.Clone() + tickInfo.feeGrowthOutside1X128 = feeGrowthGlobal1X128.Clone() } - - thisTick.initialized = true + tickInfo.initialized = true } - thisTick.liquidityGross = liquidityGrossAfter + tickInfo.liquidityGross = liquidityGrossAfter.Clone() if upper { - thisTick.liquidityNet = i256.Zero().Sub(thisTick.liquidityNet, liquidityDelta) + tickInfo.liquidityNet = i256.Zero().Sub(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) } else { - thisTick.liquidityNet = i256.Zero().Add(thisTick.liquidityNet, liquidityDelta) + tickInfo.liquidityNet = i256.Zero().Add(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) } - p.ticks[tick] = thisTick + p.setTick(tick, tickInfo) return flipped } @@ -136,47 +287,97 @@ func (p *Pool) tickCross( return thisTick.liquidityNet.Clone() } -// getTick returns a tick's state. +// setTick updates the tick data for the specified tick index in the pool. +func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { + p.ticks[tick] = newTickInfo +} + +// getTick retrieves the TickInfo associated with the specified tick index from the pool. +// If the TickInfo contains any nil fields, they are replaced with zero values using valueOrZero. +// +// Parameters: +// - tick: The tick index (int32) for which the TickInfo is to be retrieved. +// +// Behavior: +// - Retrieves the TickInfo for the given tick from the pool's tick map. +// - Ensures that all fields of TickInfo are non-nil by calling valueOrZero, which replaces nil values with zero. +// - Returns the updated TickInfo. +// +// Returns: +// - TickInfo: The tick data with all fields guaranteed to have valid values (nil fields are set to zero). +// +// Use Case: +// This function ensures the retrieved tick data is always valid and safe for further operations, +// such as calculations or updates, by sanitizing nil fields in the TickInfo structure. func (p *Pool) getTick(tick int32) TickInfo { tickInfo := p.ticks[tick] - tickInfo.init() - + tickInfo.valueOrZero() return tickInfo } -// receiver getters +// GetTickLiquidityGross returns the gross liquidity for the specified tick. func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { return p.mustGetTick(tick).liquidityGross } +// GetTickLiquidityNet returns the net liquidity for the specified tick. func (p *Pool) GetTickLiquidityNet(tick int32) *i256.Int { return p.mustGetTick(tick).liquidityNet } +// GetTickFeeGrowthOutside0X128 returns the fee growth outside the tick for token 0. func (p *Pool) GetTickFeeGrowthOutside0X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside0X128 } +// GetTickFeeGrowthOutside1X128 returns the fee growth outside the tick for token 1. func (p *Pool) GetTickFeeGrowthOutside1X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside1X128 } +// GetTickCumulativeOutside returns the cumulative liquidity outside the tick. func (p *Pool) GetTickCumulativeOutside(tick int32) int64 { return p.mustGetTick(tick).tickCumulativeOutside } +// GetTickSecondsPerLiquidityOutsideX128 returns the seconds per liquidity outside the tick. func (p *Pool) GetTickSecondsPerLiquidityOutsideX128(tick int32) *u256.Uint { return p.mustGetTick(tick).secondsPerLiquidityOutsideX128 } +// GetTickSecondsOutside returns the seconds outside the tick. func (p *Pool) GetTickSecondsOutside(tick int32) uint32 { return p.mustGetTick(tick).secondsOutside } +// GetTickInitialized returns whether the tick is initialized. func (p *Pool) GetTickInitialized(tick int32) bool { return p.mustGetTick(tick).initialized } +// mustGetTick retrieves the TickInfo for a specific tick, panicking if the tick does not exist. +// +// This function ensures that the requested tick data exists in the pool's tick mapping. +// If the tick does not exist, it panics with an appropriate error message. +// +// Parameters: +// - tick: int32, the index of the tick to retrieve. +// +// Returns: +// - TickInfo: The information associated with the specified tick. +// +// Behavior: +// - Checks if the tick exists in the pool's tick mapping (`p.ticks`). +// - If the tick exists, it returns the corresponding `TickInfo`. +// - If the tick does not exist, the function panics with a descriptive error. +// +// Panic Conditions: +// - The specified tick does not exist in the pool's mapping. +// +// Example: +// +// tickInfo := pool.mustGetTick(10) +// fmt.Println("Tick Info:", tickInfo) func (p *Pool) mustGetTick(tick int32) TickInfo { tickInfo, exist := p.ticks[tick] if !exist { diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 5c4dbf0c..53c9cc41 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -1,8 +1,6 @@ package pool import ( - "gno.land/p/demo/ufmt" - plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" @@ -12,23 +10,45 @@ import ( func tickBitmapPosition(tick int32) (int16, uint8) { wordPos := int16(tick >> 8) // tick / 256 bitPos := uint8(tick % 256) - return wordPos, bitPos } -// tickBitmapFlipTick flips tthe bit corresponding to the given tick -// in the pool's tick bitmap. +// tickBitmapFlipTick flips the state of a tick in the tick bitmap. +// +// This function toggles the "initialized" state of a tick in the tick bitmap. +// It ensures that the tick aligns with the specified tick spacing and then +// flips the corresponding bit in the bitmap representation. +// +// Parameters: +// - tick: int32, the tick index to toggle. +// - tickSpacing: int32, the spacing between valid ticks. +// The tick must align with this spacing. +// +// Workflow: +// 1. Validates that the `tick` aligns with `tickSpacing` using `checkTickSpacing`. +// 2. Computes the position of the bit in the tick bitmap: +// - `wordPos`: Determines which word in the bitmap contains the bit. +// - `bitPos`: Identifies the position of the bit within the word. +// 3. Creates a bitmask using `Lsh` (Left Shift) to target the bit at `bitPos`. +// 4. Toggles (flips) the bit using XOR with the current value of the tick bitmap. +// 5. Updates the tick bitmap with the modified word. +// +// Behavior: +// - If the bit is `0` (uninitialized), it will be flipped to `1` (initialized). +// - If the bit is `1` (initialized), it will be flipped to `0` (uninitialized). +// +// Example: +// +// pool.tickBitmapFlipTick(120, 60) +// // This flips the bit for tick 120 with a tick spacing of 60. +// +// Notes: +// - The `tick` must be divisible by `tickSpacing`. If not, the function will panic. func (p *Pool) tickBitmapFlipTick( tick int32, tickSpacing int32, ) { - if tick%tickSpacing != 0 { - panic(addDetailToError( - errInvalidTickAndTickSpacing, - ufmt.Sprintf("tick_bitmap.gno__tickBitmapFlipTick() || tick(%d) MOD tickSpacing(%d) != 0(%d)", tick, tickSpacing, tick%tickSpacing), - )) - } - + checkTickSpacing(tick, tickSpacing) wordPos, bitPos := tickBitmapPosition(tick / tickSpacing) mask := new(u256.Uint).Lsh(u256.One(), uint(bitPos)) diff --git a/pool/tick_test.gno b/pool/tick_test.gno index 34fb95ab..32fef15e 100644 --- a/pool/tick_test.gno +++ b/pool/tick_test.gno @@ -130,8 +130,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "13", want1: "12", preconditions: func() { - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(2), u256.NewUint(3), @@ -154,9 +155,10 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "13", want1: "12", preconditions: func() { - pool.deleteTick(t, 2) // delete tick from previous test - pool.setTick( + deleteTick(t, pool, 2) // delete tick from previous test + setTick( t, + pool, -2, u256.NewUint(2), u256.NewUint(3), @@ -180,8 +182,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want1: "11", preconditions: func() { // we already have tick -2 - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(4), u256.NewUint(1), @@ -204,10 +207,11 @@ func TestCalculateFeeGrowthInside(t *testing.T) { want0: "16", want1: "13", preconditions: func() { - pool.deleteTick(t, 2) - pool.deleteTick(t, -2) - pool.setTick( + deleteTick(t, pool, 2) + deleteTick(t, pool, -2) + setTick( t, + pool, -2, u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639932"), // max uint256 - 3 u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639933"), // max uint256 - 2 @@ -218,8 +222,9 @@ func TestCalculateFeeGrowthInside(t *testing.T) { 0, true, ) - pool.setTick( + setTick( t, + pool, 2, u256.NewUint(3), u256.NewUint(5), @@ -239,7 +244,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { if tt.preconditions != nil { tt.preconditions() } - got0, got1 := pool.calculateFeeGrowthInside( + got0, got1 := pool.getFeeGrowthInside( tt.tickLower, tt.tickUpper, tt.tickCurrent, @@ -247,7 +252,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { tt.feeGrowthGlobal1X128, ) if got0.ToString() != tt.want0 || got1.ToString() != tt.want1 { - t.Errorf("calculateFeeGrowthInside() = (%v, %v), want (%v, %v)", + t.Errorf("getFeeGrowthInside() = (%v, %v), want (%v, %v)", got0.ToString(), got1.ToString(), tt.want0, tt.want1) } }) @@ -287,7 +292,7 @@ func TestTickUpdate(t *testing.T) { { name: "does not flip from nonzero to greater nonzero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -302,7 +307,7 @@ func TestTickUpdate(t *testing.T) { { name: "flips from nonzero to zero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -317,7 +322,7 @@ func TestTickUpdate(t *testing.T) { { name: "does not flip from nonzero to lesser nonzero", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) }, tick: 0, @@ -332,7 +337,7 @@ func TestTickUpdate(t *testing.T) { { name: "reverts if total liquidity gross is greater than max", preconditions: func() { - pool.deleteTick(t, 0) + deleteTick(t, pool, 0) pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), true, u256.NewUint(3)) @@ -384,8 +389,8 @@ func TestTickUpdate(t *testing.T) { { name: "assumes all growth happens below ticks lte current tick", preconditions: func() { - pool.deleteTick(t, 0) - pool.deleteTick(t, 1) + deleteTick(t, pool, 0) + deleteTick(t, pool, 1) pool.tickUpdate(1, 1, i256.One(), u256.One(), u256.NewUint(2), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) }, tick: 0, @@ -574,8 +579,125 @@ func TestGetTick(t *testing.T) { } } -func (pool *Pool) setTick( +func TestGetFeeGrowthBelowX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + lowerTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // fee growth outside for token 1 + } + + tests := []struct { + name string + tickLower int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent >= tickLower - Return feeGrowthOutside directly", + tickLower: 100, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent > tickLower - Return feeGrowthOutside directly", + tickLower: 50, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent < tickLower - Subtract feeGrowthOutside from global", + tickLower: 100, + tickCurrent: 50, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( + tt.tickLower, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + lowerTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func TestGetFeeGrowthAboveX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + upperTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // Fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // Fee growth outside for token 1 + } + + tests := []struct { + name string + tickUpper int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent < tickUpper - Return feeGrowthOutside directly", + tickUpper: 100, + tickCurrent: 50, + expectedFeeGrowth0: upperTick.feeGrowthOutside0X128, // 300 + expectedFeeGrowth1: upperTick.feeGrowthOutside1X128, // 500 + }, + { + name: "tickCurrent >= tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 150, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + { + name: "tickCurrent == tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 100, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( + tt.tickUpper, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + upperTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func setTick( t *testing.T, + pool *Pool, tick int32, feeGrowthOutside0X128 *u256.Uint, feeGrowthOutside1X128 *u256.Uint, @@ -589,7 +711,7 @@ func (pool *Pool) setTick( t.Helper() info := pool.ticks[tick] - info.init() + info.valueOrZero() info.feeGrowthOutside0X128 = feeGrowthOutside0X128 info.feeGrowthOutside1X128 = feeGrowthOutside1X128 @@ -603,8 +725,7 @@ func (pool *Pool) setTick( pool.ticks[tick] = info } -func (pool *Pool) deleteTick(t *testing.T, tick int32) { +func deleteTick(t *testing.T, pool *Pool, tick int32) { t.Helper() - delete(pool.ticks, tick) } diff --git a/pool/type.gno b/pool/type.gno index 0135c34e..2cca5dbc 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -227,7 +227,34 @@ type PositionInfo struct { tokensOwed1 *u256.Uint } -func (p *PositionInfo) init() { +// valueOrZero initializes nil fields in PositionInfo to zero. +// +// This function ensures that all numeric fields in the PositionInfo struct are not nil. +// If a field is nil, it is replaced with a zero value, maintaining consistency and preventing +// potential null pointer issues during calculations. +// +// Fields affected: +// - liquidity: The liquidity amount associated with the position. +// - feeGrowthInside0LastX128: Fee growth for token 0 inside the tick range, last recorded value. +// - feeGrowthInside1LastX128: Fee growth for token 1 inside the tick range, last recorded value. +// - tokensOwed0: The amount of token 0 owed to the position owner. +// - tokensOwed1: The amount of token 1 owed to the position owner. +// +// Behavior: +// - If a field is nil, it is set to its equivalent zero value. +// - If a field already has a value, it remains unchanged. +// +// Example: +// +// position := &PositionInfo{} +// position.valueOrZero() +// fmt.Println(position.liquidity) // Output: 0 +// +// Notes: +// - This function is useful for ensuring numeric fields are properly initialized +// before performing operations or calculations. +// - Prevents runtime errors caused by nil values. +func (p *PositionInfo) valueOrZero() { p.liquidity = p.liquidity.NilToZero() p.feeGrowthInside0LastX128 = p.feeGrowthInside0LastX128.NilToZero() p.feeGrowthInside1LastX128 = p.feeGrowthInside1LastX128.NilToZero() @@ -259,7 +286,26 @@ type TickInfo struct { initialized bool // whether the tick is initialized } -func (t *TickInfo) init() { +// valueOrZero ensures that all fields of TickInfo are valid by setting nil fields to zero, +// while retaining existing values if they are not nil. +// This function updates the TickInfo struct to replace any nil values in its fields +// with their respective zero values, ensuring data consistency. +// +// Behavior: +// - If a field is nil, it is replaced with its zero value. +// - If a field already has a valid value, the value remains unchanged. +// +// Fields: +// - liquidityGross: Gross liquidity for the tick, set to zero if nil, otherwise retains its value. +// - liquidityNet: Net liquidity for the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside0X128: Accumulated fee growth for token0 outside the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside1X128: Accumulated fee growth for token1 outside the tick, set to zero if nil, otherwise retains its value. +// - secondsPerLiquidityOutsideX128: Time per liquidity outside the tick, set to zero if nil, otherwise retains its value. +// +// Use Case: +// This function ensures all numeric fields in TickInfo are non-nil and have valid values, +// preventing potential runtime errors caused by nil values during operations like arithmetic or comparisons. +func (t *TickInfo) valueOrZero() { t.liquidityGross = t.liquidityGross.NilToZero() t.liquidityNet = t.liquidityNet.NilToZero() t.feeGrowthOutside0X128 = t.feeGrowthOutside0X128.NilToZero() @@ -327,71 +373,71 @@ func newPool(poolInfo *createPoolParams) *Pool { } } -func (p *Pool) GetToken0Path() string { +func (p *Pool) Token0Path() string { return p.token0Path } -func (p *Pool) GetToken1Path() string { +func (p *Pool) Token1Path() string { return p.token1Path } -func (p *Pool) GetFee() uint32 { +func (p *Pool) Fee() uint32 { return p.fee } -func (p *Pool) GetBalanceToken0() *u256.Uint { +func (p *Pool) BalanceToken0() *u256.Uint { return p.balances.token0 } -func (p *Pool) GetBalanceToken1() *u256.Uint { +func (p *Pool) BalanceToken1() *u256.Uint { return p.balances.token1 } -func (p *Pool) GetTickSpacing() int32 { +func (p *Pool) TickSpacing() int32 { return p.tickSpacing } -func (p *Pool) GetMaxLiquidityPerTick() *u256.Uint { +func (p *Pool) MaxLiquidityPerTick() *u256.Uint { return p.maxLiquidityPerTick } -func (p *Pool) GetSlot0() Slot0 { +func (p *Pool) Slot0() Slot0 { return p.slot0 } -func (p *Pool) GetSlot0SqrtPriceX96() *u256.Uint { +func (p *Pool) Slot0SqrtPriceX96() *u256.Uint { return p.slot0.sqrtPriceX96 } -func (p *Pool) GetSlot0Tick() int32 { +func (p *Pool) Slot0Tick() int32 { return p.slot0.tick } -func (p *Pool) GetSlot0FeeProtocol() uint8 { +func (p *Pool) Slot0FeeProtocol() uint8 { return p.slot0.feeProtocol } -func (p *Pool) GetSlot0Unlocked() bool { +func (p *Pool) Slot0Unlocked() bool { return p.slot0.unlocked } -func (p *Pool) GetFeeGrowthGlobal0X128() *u256.Uint { +func (p *Pool) FeeGrowthGlobal0X128() *u256.Uint { return p.feeGrowthGlobal0X128 } -func (p *Pool) GetFeeGrowthGlobal1X128() *u256.Uint { +func (p *Pool) FeeGrowthGlobal1X128() *u256.Uint { return p.feeGrowthGlobal1X128 } -func (p *Pool) GetProtocolFeesToken0() *u256.Uint { +func (p *Pool) ProtocolFeesToken0() *u256.Uint { return p.protocolFees.token0 } -func (p *Pool) GetProtocolFeesToken1() *u256.Uint { +func (p *Pool) ProtocolFeesToken1() *u256.Uint { return p.protocolFees.token1 } -func (p *Pool) GetLiquidity() *u256.Uint { +func (p *Pool) Liquidity() *u256.Uint { return p.liquidity } diff --git a/pool/utils.gno b/pool/utils.gno index e4511621..c731d361 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -4,24 +4,129 @@ import ( "std" "gno.land/p/demo/ufmt" - pusers "gno.land/p/demo/users" - + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) -func safeConvertToUint64(value *u256.Uint) (uint64, error) { +// safeConvertToUint64 safely converts a *u256.Uint value to a uint64, ensuring no overflow. +// +// This function attempts to convert the given *u256.Uint value to a uint64. If the value exceeds +// the maximum allowable range for uint64 (`2^64 - 1`), it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - uint64: The converted value if it falls within the uint64 range. +// +// Panics: +// - If the `value` exceeds the range of uint64, the function will panic with an error indicating +// the overflow and the original value. +// +// Notes: +// - This function uses the `Uint64WithOverflow` method to detect overflow during the conversion. +// - It is essential to validate large values before calling this function to avoid unexpected panics. +// +// Example: +// safeValue := safeConvertToUint64(u256.MustFromDecimal("18446744073709551615")) // Valid conversion +// safeConvertToUint64(u256.MustFromDecimal("18446744073709551616")) // Panics due to overflow +func safeConvertToUint64(value *u256.Uint) uint64 { res, overflow := value.Uint64WithOverflow() if overflow { - return 0, ufmt.Errorf( + panic(ufmt.Sprintf( "%v: amount(%s) overflows uint64 range", - errOutOfRange, value.ToString(), - ) + errOutOfRange, value.ToString())) } + return res +} + +// safeConvertToInt128 safely converts a *u256.Uint value to an *i256.Int, ensuring it does not exceed the int128 range. +// +// This function converts an unsigned 256-bit integer (*u256.Uint) into a signed 256-bit integer (*i256.Int). +// It checks whether the resulting value falls within the valid range of int128 (`-2^127` to `2^127 - 1`). +// If the value exceeds the maximum allowable int128 range, it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - *i256.Int: The converted value if it falls within the int128 range. +// +// Panics: +// - If the converted value exceeds the maximum int128 value (`2^127 - 1`), the function will panic with an +// error message indicating the overflow and the original value. +// +// Notes: +// - The function uses `i256.FromUint256` to perform the conversion. +// - The constant `MAX_INT128` is used to define the upper bound of the int128 range (`170141183460469231731687303715884105727`). +// +// Example: +// validInt128 := safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105727")) // Valid conversion +// safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105728")) // Panics due to overflow +func safeConvertToInt128(value *u256.Uint) *i256.Int { + liquidityDelta := i256.FromUint256(value) + if liquidityDelta.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } + return liquidityDelta +} - return res, nil +// toUint128 ensures a *u256.Uint value fits within the uint128 range. +// +// This function validates that the given `value` is properly initialized and checks whether +// it exceeds the maximum value of uint128. If the value exceeds the uint128 range, +// it applies a masking operation to truncate the value to fit within the uint128 limit. +// +// Parameters: +// - value: *u256.Uint, the value to be checked and possibly truncated. +// +// Returns: +// - *u256.Uint: A value guaranteed to fit within the uint128 range. +// +// Notes: +// - The mask ensures that only the lower 128 bits of the value are retained. +// - If the input value is already within the uint128 range, it remains unchanged. +// - MAX_UINT128 is a constant representing `2^128 - 1`. +func toUint128(value *u256.Uint) *u256.Uint { + assertOnlyInitializedUint256(value) + if value.Gt(u256.MustFromDecimal(consts.MAX_UINT128)) { + mask := new(u256.Uint).Lsh(u256.One(), consts.Q128_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + value = value.And(value, mask) + } + return value } +// a2u converts a std.Address to a pusers.AddressOrName, ensuring the input address is valid. +// +// This function takes a `std.Address` and verifies its validity. If the address is invalid, +// the function triggers a panic with an appropriate error message. For valid addresses, +// it performs the conversion to `pusers.AddressOrName`. +// +// Parameters: +// - addr (std.Address): The input address to be converted. +// +// Returns: +// - pusers.AddressOrName: The converted address, wrapped as a `pusers.AddressOrName` type. +// +// Panics: +// - If the provided `addr` is invalid, the function will panic with an error indicating +// the invalid address. +// +// Notes: +// - The function relies on the `addr.IsValid()` method to determine the validity of the input address. +// - It uses `addDetailToError` to provide additional context for the error message when an invalid +// address is encountered. +// +// Example: +// converted := a2u(std.Address("validAddress")) // Successful conversion +// a2u(std.Address("")) // Panics due to invalid address func a2u(addr std.Address) pusers.AddressOrName { if !addr.IsValid() { panic(addDetailToError( @@ -32,23 +137,60 @@ func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } +// u256Min returns the smaller of two *u256.Uint values. +// +// This function compares two unsigned 256-bit integers and returns the smaller of the two. +// If `num1` is less than `num2`, it returns `num1`; otherwise, it returns `num2`. +// +// Parameters: +// - num1 (*u256.Uint): The first unsigned 256-bit integer. +// - num2 (*u256.Uint): The second unsigned 256-bit integer. +// +// Returns: +// - *u256.Uint: The smaller of `num1` and `num2`. +// +// Notes: +// - This function uses the `Lt` (less than) method of `*u256.Uint` to perform the comparison. +// - The function assumes both input values are non-nil. If nil inputs are possible in the usage context, +// additional validation may be needed. +// +// Example: +// smaller := u256Min(u256.MustFromDecimal("10"), u256.MustFromDecimal("20")) // Returns 10 +// smaller := u256Min(u256.MustFromDecimal("30"), u256.MustFromDecimal("20")) // Returns 20 func u256Min(num1, num2 *u256.Uint) *u256.Uint { if num1.Lt(num2) { return num1 } - return num2 } -func isUserCall() bool { - return std.PrevRealm().IsUser() +// derivePkgAddr derives the Realm address from it's pkgPath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) } -func getPrev() (string, string) { - prev := std.PrevRealm() +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() return prev.Addr().String(), prev.PkgPath() } +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkTransferError checks transfer error. func checkTransferError(err error) { if err != nil { panic(addDetailToError( @@ -57,3 +199,79 @@ func checkTransferError(err error) { )) } } + +// checkOverFlowInt128 checks if the value overflows the int128 range. +func checkOverFlowInt128(value *i256.Int) { + if value.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } +} + +// checkTickSpacing checks if the tick is divisible by the tickSpacing. +func checkTickSpacing(tick, tickSpacing int32) { + if tick%tickSpacing != 0 { + panic(addDetailToError( + errInvalidTickAndTickSpacing, + ufmt.Sprintf("tick(%d) MOD tickSpacing(%d) != 0(%d)", tick, tickSpacing, tick%tickSpacing), + )) + } +} + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyPositionContract panics if the caller is not the position contract. +func assertOnlyPositionContract() { + caller := getPrevAddr() + if err := common.PositionOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only position(%s) can call, called from %s", consts.POSITION_ADDR, caller.String()), + )) + } +} + +// assertOnlyInitializedUint256 panics if the value is nil. +func assertOnlyInitializedUint256(value *u256.Uint) { + if value == nil { + panic(addDetailToError( + errInvalidInput, + "value is nil", + )) + } +} + +// assertOnlyAdmin panics if the caller is not the admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyGovernance panics if the caller is not the governance. +func assertOnlyGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyRegistered panics if the token is not registered. +func assertOnlyRegistered(tokenPath string) { + common.MustRegistered(tokenPath) +} diff --git a/pool/utils_test.gno b/pool/utils_test.gno index 623cc4ae..e5457894 100644 --- a/pool/utils_test.gno +++ b/pool/utils_test.gno @@ -6,10 +6,12 @@ import ( "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" - + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" "gno.land/r/demo/users" + "gno.land/r/gnoswap/v1/consts" ) func TestA2U(t *testing.T) { @@ -125,7 +127,7 @@ func TestIsUserCall(t *testing.T) { } } -func TestGetPrev(t *testing.T) { +func TestGetPrevAsString(t *testing.T) { tests := []struct { name string action func() (string, string) @@ -137,7 +139,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { userRealm := std.NewUserRealm(std.Address("user")) std.TestSetRealm(userRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: "user", expectedPkgPath: "", @@ -147,7 +149,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { codeRealm := std.NewCodeRealm("gno.land/r/demo/realm") std.TestSetRealm(codeRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: std.DerivePkgAddr("gno.land/r/demo/realm").String(), expectedPkgPath: "gno.land/r/demo/realm", @@ -162,3 +164,290 @@ func TestGetPrev(t *testing.T) { }) } } + +func TestSafeConvertToUint64(t *testing.T) { + tests := []struct { + name string + value *u256.Uint + wantRes uint64 + wantPanic bool + }{ + {"normal conversion", u256.NewUint(123), 123, false}, + {"overflow", u256.MustFromDecimal(consts.MAX_UINT128), 0, true}, + {"max uint64", u256.NewUint(1<<64 - 1), 1<<64 - 1, false}, + {"zero", u256.NewUint(0), 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToUint64(tt.value) + if res != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestSafeConvertToInt128(t *testing.T) { + tests := []struct { + name string + value string + wantRes string + wantPanic bool + }{ + {"normal conversion", "170141183460469231731687303715884105727", "170141183460469231731687303715884105727", false}, + {"overflow", "170141183460469231731687303715884105728", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToInt128(u256.MustFromDecimal(tt.value)) + if res.ToString() != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestA2u(t *testing.T) { + var ( + addr = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + input std.Address + expected pusers.AddressOrName + }{ + { + name: "Success - a2u", + input: addr, + expected: pusers.AddressOrName(addr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := a2u(tc.input) + uassert.Equal(t, users.Resolve(got).String(), users.Resolve(tc.expected).String()) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestCheckOverFlowInt128(t *testing.T) { + tests := []struct { + name string + input *i256.Int + shouldPanic bool + expected string + }{ + { + name: "Valid value within int128 range", + input: i256.MustFromDecimal("1"), + shouldPanic: false, + }, + { + name: "Edge case - MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT128), + shouldPanic: false, + }, + { + name: "Overflow case - exceeds MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT256), // 최대값 + 1 + shouldPanic: true, + expected: "[GNOSWAP-POOL-026] overflow: amount(170141183460469231731687303715884105728) overflows int128 range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + t.Errorf("Expected panic but none occurred") + } + }() + checkOverFlowInt128(tt.input) + }) + } +} + +func TestCheckTickSpacing(t *testing.T) { + tests := []struct { + name string + tick int32 + tickSpacing int32 + shouldPanic bool + expected string + }{ + { + name: "Valid tick - divisible by tickSpacing", + tick: 120, + tickSpacing: 60, + shouldPanic: false, + }, + { + name: "Valid tick - zero tick", + tick: 0, + tickSpacing: 10, + shouldPanic: false, + }, + { + name: "Invalid tick - not divisible", + tick: 15, + tickSpacing: 10, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(15) MOD tickSpacing(10) != 0(5)", + }, + { + name: "Invalid tick - negative tick", + tick: -35, + tickSpacing: 20, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(-35) MOD tickSpacing(20) != 0(-15)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + } + }() + checkTickSpacing(tt.tick, tt.tickSpacing) + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POOL-023] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} diff --git a/position/_RPC_api.gno b/position/_RPC_api.gno index 10df9acf..5b1cd3d0 100644 --- a/position/_RPC_api.gno +++ b/position/_RPC_api.gno @@ -7,14 +7,11 @@ import ( "gno.land/p/demo/json" "gno.land/p/demo/ufmt" + i256 "gno.land/p/gnoswap/int256" "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" - - pl "gno.land/r/gnoswap/v1/pool" - - i256 "gno.land/p/gnoswap/int256" - "gno.land/r/gnoswap/v1/gnft" + pl "gno.land/r/gnoswap/v1/pool" ) type RpcPosition struct { @@ -69,7 +66,7 @@ func ApiGetPositions() string { } // STAT NODE - _stat := json.ObjectNode("", map[string]*json.Node{ + stat := json.ObjectNode("", map[string]*json.Node{ "height": json.NumberNode("height", float64(std.GetHeight())), "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), }) @@ -99,7 +96,7 @@ func ApiGetPositions() string { } node := json.ObjectNode("", map[string]*json.Node{ - "stat": _stat, + "stat": stat, "response": responses, }) @@ -390,7 +387,7 @@ func rpcMakePosition(lpTokenId uint64) RpcPosition { burned := isBurned(lpTokenId) pool := pl.GetPoolFromPoolPath(position.poolKey) - currentX96 := pool.GetSlot0SqrtPriceX96() + currentX96 := pool.Slot0SqrtPriceX96() lowerX96 := common.TickMathGetSqrtRatioAtTick(position.tickLower) upperX96 := common.TickMathGetSqrtRatioAtTick(position.tickUpper) @@ -439,12 +436,12 @@ func unclaimedFee(tokenId uint64) (*i256.Int, *i256.Int) { poolKey := positions[tokenId].poolKey pool := pl.GetPoolFromPoolPath(poolKey) - currentTick := pool.GetSlot0Tick() + currentTick := pool.Slot0Tick() - _feeGrowthGlobal0X128 := pool.GetFeeGrowthGlobal0X128() // u256 + _feeGrowthGlobal0X128 := pool.FeeGrowthGlobal0X128() // u256 feeGrowthGlobal0X128 := i256.FromUint256(_feeGrowthGlobal0X128) // i256 - _feeGrowthGlobal1X128 := pool.GetFeeGrowthGlobal1X128() // u256 + _feeGrowthGlobal1X128 := pool.FeeGrowthGlobal1X128() // u256 feeGrowthGlobal1X128 := i256.FromUint256(_feeGrowthGlobal1X128) // i256 _tickUpperFeeGrowthOutside0X128 := pool.GetTickFeeGrowthOutside0X128(tickUpper) // u256 diff --git a/position/liquidity_management.gno b/position/liquidity_management.gno index be926f91..d700bfb3 100644 --- a/position/liquidity_management.gno +++ b/position/liquidity_management.gno @@ -16,7 +16,7 @@ import ( func addLiquidity(params AddLiquidityParams) (*u256.Uint, *u256.Uint, *u256.Uint) { pool := pl.GetPoolFromPoolPath(params.poolKey) - sqrtPriceX96 := pool.GetSlot0SqrtPriceX96() + sqrtPriceX96 := pool.Slot0SqrtPriceX96() sqrtRatioAX96 := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioBX96 := common.TickMathGetSqrtRatioAtTick(params.tickUpper) diff --git a/position/position.gno b/position/position.gno index f33ffd0b..42503da3 100644 --- a/position/position.gno +++ b/position/position.gno @@ -189,7 +189,8 @@ func mint(params MintParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint) { nextId++ positionKey := positionKeyCompute(GetOrigPkgAddr(), params.tickLower, params.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -320,7 +321,8 @@ func increaseLiquidity(params IncreaseLiquidityParams) (uint64, *u256.Uint, *u25 pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -446,7 +448,8 @@ func decreaseLiquidity(params DecreaseLiquidityParams) (uint64, *u256.Uint, *u25 verifyBurnedAmounts(burnedAmount0, burnedAmount1, params.amount0Min, params.amount1Min) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -592,7 +595,8 @@ func Reposition( pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), tickLower, tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -676,7 +680,8 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) pool := pl.GetPoolFromPoolPath(position.poolKey) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128 := pool.PositionFeeGrowthInside0LastX128(positionKey) + _feeGrowthInside1LastX128 := pool.PositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString())