From 2e3de24e2b61997614523121aef7ff8fce02079a Mon Sep 17 00:00:00 2001 From: n3wbie Date: Wed, 18 Dec 2024 16:48:49 +0900 Subject: [PATCH] refactor: use avl.Tree in pool contract --- pool/_helper_test.gno | 13 +++-- pool/api.gno | 43 +++++++++----- pool/api_test.gno | 2 +- pool/getter.gno | 6 +- pool/getter_test.gno | 16 +++--- pool/pool.gno | 17 +++--- pool/pool_manager.gno | 82 ++++++++++++++++----------- pool/pool_manager_test.gno | 31 +--------- pool/pool_test.gno | 7 ++- pool/position.gno | 38 ++++++++----- pool/position_update.gno | 4 +- pool/position_update_test.gno | 4 +- pool/protocol_fee_withdrawal_test.gno | 3 +- pool/swap_test.gno | 17 +++--- pool/tick.gno | 30 +++++++--- pool/tick_bitmap.gno | 29 ++++++++-- pool/tick_bitmap_test.gno | 7 ++- pool/tick_test.gno | 41 +++++++------- pool/type.gno | 29 +++++----- pool/utils.gno | 2 +- 20 files changed, 236 insertions(+), 185 deletions(-) diff --git a/pool/_helper_test.gno b/pool/_helper_test.gno index ea719ff2c..34826edaa 100644 --- a/pool/_helper_test.gno +++ b/pool/_helper_test.gno @@ -13,6 +13,7 @@ import ( "gno.land/r/onbloc/qux" "gno.land/r/onbloc/usdc" + "gno.land/p/demo/avl" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" pusers "gno.land/p/demo/users" @@ -23,7 +24,7 @@ import ( const ( ugnotDenom string = "ugnot" - ugnotPath string = "gno.land/r/gnoswap/v1/pool:ugnot" + ugnotPath string = "ugnot" wugnotPath string = "gno.land/r/demo/wugnot" gnsPath string = "gno.land/r/gnoswap/v1/gns" barPath string = "gno.land/r/onbloc/bar" @@ -517,7 +518,7 @@ func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { // resetObject resets the object state(clear or make it default values) func resetObject(t *testing.T) { - pools = make(poolMap) + pools = avl.NewTree() slot0FeeProtocol = 0 poolCreationFee = 100_000_000 withdrawalFeeBPS = 100 @@ -570,11 +571,11 @@ func burnUsdc(addr pusers.AddressOrName) { func TestBeforeResetObject(t *testing.T) { // make some data - pools = make(poolMap) - pools["gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc"] = &Pool{ + pools = avl.NewTree() + pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{ token0Path: "gno.land/r/gnoswap/v1/gns", token1Path: "gno.land/r/onbloc/usdc", - } + }) slot0FeeProtocol = 1 poolCreationFee = 100_000_000 @@ -591,7 +592,7 @@ func TestBeforeResetObject(t *testing.T) { func TestResetObject(t *testing.T) { resetObject(t) - uassert.Equal(t, len(pools), 0) + uassert.Equal(t, pools.Size(), 0) uassert.Equal(t, slot0FeeProtocol, uint8(0)) uassert.Equal(t, poolCreationFee, uint64(100_000_000)) uassert.Equal(t, withdrawalFeeBPS, uint64(100)) diff --git a/pool/api.gno b/pool/api.gno index afe4d9fa6..32c15a7db 100644 --- a/pool/api.gno +++ b/pool/api.gno @@ -4,6 +4,7 @@ import ( b64 "encoding/base64" "gno.land/p/demo/json" + u256 "gno.land/p/gnoswap/uint256" "std" "strconv" @@ -81,14 +82,16 @@ type RpcPosition struct { func ApiGetPools() string { rpcPools := []RpcPool{} - for poolPath, _ := range pools { + pools.Iterate("", "", func(poolPath string, value interface{}) bool { rpcPool := rpcMakePool(poolPath) rpcPools = append(rpcPools, rpcPool) - } + + return false + }) responses := json.ArrayNode("", []*json.Node{}) for _, pool := range rpcPools { - _poolNode := json.ObjectNode("", map[string]*json.Node{ + poolNode := json.ObjectNode("", map[string]*json.Node{ "poolPath": json.StringNode("poolPath", pool.PoolPath), "token0Path": json.StringNode("token0Path", pool.Token0Path), "token1Path": json.StringNode("token1Path", pool.Token1Path), @@ -110,7 +113,7 @@ func ApiGetPools() string { "tickBitmaps": json.ObjectNode("tickBitmaps", makeRpcTickBitmapsJson(pool.TickBitmaps)), "positions": json.ArrayNode("positions", makeRpcPositionsArray(pool.Positions)), }) - responses.AppendArray(_poolNode) + responses.AppendArray(poolNode) } node := json.ObjectNode("", map[string]*json.Node{ @@ -122,8 +125,7 @@ func ApiGetPools() string { } func ApiGetPool(poolPath string) string { - _, exist := pools[poolPath] - if !exist { + if !pools.Has(poolPath) { return "" } rpcPool := rpcMakePool(poolPath) @@ -198,8 +200,11 @@ func rpcMakePool(poolPath string) RpcPool { rpcPool.Liquidity = pool.liquidity.ToString() rpcPool.Ticks = RpcTicks{} - for tick, tickInfo := range pool.ticks { - rpcPool.Ticks[tick] = RpcTickInfo{ + pool.ticks.Iterate("", "", func(tickStr string, iTickInfo interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + tickInfo := iTickInfo.(TickInfo) + + rpcPool.Ticks[int32(tick)] = RpcTickInfo{ LiquidityGross: tickInfo.liquidityGross.ToString(), LiquidityNet: tickInfo.liquidityNet.ToString(), FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(), @@ -209,17 +214,22 @@ func rpcMakePool(poolPath string) RpcPool { SecondsOutside: tickInfo.secondsOutside, Initialized: tickInfo.initialized, } - } + + return false + }) rpcPool.TickBitmaps = RpcTickBitmaps{} - for tick, tickBitmap := range pool.tickBitmaps { - rpcPool.TickBitmaps[tick] = tickBitmap.ToString() - } + pool.tickBitmaps.Iterate("", "", func(tickStr string, iTickBitmap interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + pool.setTickBitmap(int16(tick), iTickBitmap.(*u256.Uint)) + + return false + }) - Positions := pool.positions rpcPositions := []RpcPosition{} - for posKey, posInfo := range Positions { + pool.positions.Iterate("", "", func(posKey string, iPositionInfo interface{}) bool { owner, tickLower, tickUpper := posKeyDivide(posKey) + posInfo := iPositionInfo.(PositionInfo) rpcPositions = append(rpcPositions, RpcPosition{ Owner: owner, @@ -229,7 +239,10 @@ func rpcMakePool(poolPath string) RpcPool { Token0Owed: posInfo.tokensOwed0.ToString(), Token1Owed: posInfo.tokensOwed1.ToString(), }) - } + + return false + }) + rpcPool.Positions = rpcPositions return rpcPool diff --git a/pool/api_test.gno b/pool/api_test.gno index 58031e921..fcef1abe2 100644 --- a/pool/api_test.gno +++ b/pool/api_test.gno @@ -21,7 +21,7 @@ func TestInitTwoPools(t *testing.T) { // bar:baz CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000 - uassert.Equal(t, len(pools), 2) + uassert.Equal(t, pools.Size(), 2) } func TestApiGetPools(t *testing.T) { diff --git a/pool/getter.gno b/pool/getter.gno index 443516141..6c22f1d15 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -2,9 +2,11 @@ package pool func PoolGetPoolList() []string { poolPaths := []string{} - for poolPath, _ := range pools { + pools.Iterate("", "", func(poolPath string, _ interface{}) bool { poolPaths = append(poolPaths, poolPath) - } + + return false + }) return poolPaths } diff --git a/pool/getter_test.gno b/pool/getter_test.gno index 1772cc228..e0727669b 100644 --- a/pool/getter_test.gno +++ b/pool/getter_test.gno @@ -3,6 +3,8 @@ package pool import ( "testing" + "gno.land/p/demo/avl" + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" ) @@ -35,8 +37,8 @@ func TestInitData(t *testing.T) { liquidity: u256.NewUint(1000000), } - mockTicks := Ticks{} - mockTicks[0] = TickInfo{ + mockTicks := avl.NewTree() + mockTicks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000000), liquidityNet: i256.NewInt(2000000), feeGrowthOutside0X128: u256.NewUint(3000000), @@ -45,20 +47,20 @@ func TestInitData(t *testing.T) { secondsPerLiquidityOutsideX128: u256.NewUint(6000000), secondsOutside: 7, initialized: true, - } + }) mockPool.ticks = mockTicks - mockPositions := Positions{} - mockPositions["test_position"] = PositionInfo{ + mockPositions := avl.NewTree() + mockPositions.Set("test_position", PositionInfo{ liquidity: u256.NewUint(1000000), feeGrowthInside0LastX128: u256.NewUint(2000000), feeGrowthInside1LastX128: u256.NewUint(3000000), tokensOwed0: u256.NewUint(4000000), tokensOwed1: u256.NewUint(5000000), - } + }) mockPool.positions = mockPositions - pools["token0:token1:3000"] = mockPool + pools.Set("token0:token1:3000", mockPool) } func TestPoolGetters(t *testing.T) { diff --git a/pool/pool.gno b/pool/pool.gno index df2e51ee8..67308c228 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -124,7 +124,7 @@ func Burn( } positionKey := getPositionKey(caller, tickLower, tickUpper) - pool.positions[positionKey] = position + pool.setPosition(positionKey, position) // actual token transfer happens in Collect() return amount0.ToString(), amount1.ToString() @@ -192,7 +192,7 @@ func Collect( checkTransferError(token1.Transfer(recipient, amount1.Uint64())) } - pool.positions[positionKey] = position + pool.setPosition(positionKey, position) return amount0.ToString(), amount1.ToString() } @@ -316,12 +316,15 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { // - feePrtocol0 occupies the lower 4 bits // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) + // Update slot0 for each pool - for _, pool := range pools { - if pool != nil { - pool.slot0.feeProtocol = newFee - } - } + pools.Iterate("", "", func(poolPath string, iPool interface{}) bool { + pool := iPool.(*Pool) + pool.slot0.feeProtocol = newFee + + return false + }) + // update slot0 slot0FeeProtocol = newFee return newFee diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index 575b95594..cd04d0416 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -2,8 +2,10 @@ package pool import ( "std" + "strconv" "strings" + "gno.land/p/demo/avl" "gno.land/p/demo/ufmt" "gno.land/r/gnoswap/v1/consts" @@ -15,39 +17,18 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) -type poolMap map[string]*Pool - -func (pm *poolMap) Get(poolPath string) (*Pool, bool) { - pool, exist := (*pm)[poolPath] - if !exist { - return nil, false - } - - return pool, true -} - -func (pm *poolMap) Set(poolPath string, pool *Pool) { - (*pm)[poolPath] = pool -} - -type tickSpacingMap map[uint32]int32 - -func (t *tickSpacingMap) Get(fee uint32) int32 { - return (*t)[fee] -} - var ( - feeAmountTickSpacing tickSpacingMap = make(tickSpacingMap) // maps fee to tickSpacing || map[feeAmount]tick_spacing - pools poolMap = make(poolMap) // maps poolPath to pool || map[poolPath]*Pool + feeAmountTickSpacing = avl.NewTree() // feeBps(uint32) -> tickSpacing(int32) + pools = avl.NewTree() // poolPath -> *Pool slot0FeeProtocol uint8 = 0 ) func init() { - feeAmountTickSpacing[100] = 1 // 0.01% - feeAmountTickSpacing[500] = 10 // 0.05% - feeAmountTickSpacing[3000] = 60 // 0.3% - feeAmountTickSpacing[10000] = 200 // 1% + setFeeAmountTickSpacing(100, 1) // 0.01% + setFeeAmountTickSpacing(500, 10) // 0.05% + setFeeAmountTickSpacing(3000, 60) // 0.3% + setFeeAmountTickSpacing(10000, 200) // 1% } // createPoolParams holds the essential parameters for creating a new pool. @@ -66,7 +47,7 @@ func newPoolParams( sqrtPriceX96 string, ) *createPoolParams { price := u256.MustFromDecimal(sqrtPriceX96) - tickSpacing := feeAmountTickSpacing.Get(fee) + tickSpacing := GetFeeAmountTickSpacing(fee) return &createPoolParams{ token0Path: token0Path, token1Path: token1Path, @@ -211,8 +192,7 @@ func CreatePool( // DoesPoolPathExist checks if a pool exists for the given poolPath. // The poolPath is a unique identifier for a pool, combining token paths and fee. func DoesPoolPathExist(poolPath string) bool { - _, exist := pools[poolPath] - return exist + return pools.Has(poolPath) } // GetPool retrieves a pool instance based on the provided token paths and fee tier. @@ -260,14 +240,14 @@ func GetPool(token0Path, token1Path string, fee uint32) *Pool { // Example: // pool := GetPoolFromPoolPath("path/to/pool") func GetPoolFromPoolPath(poolPath string) *Pool { - pool, exist := pools[poolPath] + iPool, exist := pools.Get(poolPath) if !exist { panic(addDetailToError( errDataNotFound, ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), )) } - return pool + return iPool.(*Pool) } // GetPoolPath generates a unique pool path string based on the token paths and fee tier. @@ -303,3 +283,41 @@ func GetPoolPath(token0Path, token1Path string, fee uint32) string { } return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) } + +// GetFeeAmountTickSpacing retrieves the tick spacing associated with a given fee amount. +// The tick spacing determines the minimum distance between ticks in the pool. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// +// Returns: +// - int32: The tick spacing value for the given fee tier +// +// Panics: +// - If the fee amount is not registered in feeAmountTickSpacing +func GetFeeAmountTickSpacing(fee uint32) int32 { + feeStr := strconv.FormatUint(uint64(fee), 10) + iTickSpacing, exist := feeAmountTickSpacing.Get(feeStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected feeAmountTickSpacing(%s) to exist", feeStr), + )) + } + + return iTickSpacing.(int32) +} + +// setFeeAmountTickSpacing associates a tick spacing value with a fee amount. +// This is typically called during initialization to set up supported fee tiers. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// - tickSpacing (int32): The minimum tick spacing for this fee tier +// +// Note: Smaller tick spacing allows for more granular price points but increases +// computational overhead. Higher fee tiers typically use larger tick spacing. +func setFeeAmountTickSpacing(fee uint32, tickSpacing int32) { + feeStr := strconv.FormatUint(uint64(fee), 10) + feeAmountTickSpacing.Set(feeStr, tickSpacing) +} diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 36637a281..4ea35cd95 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -8,29 +8,6 @@ import ( "gno.land/r/gnoswap/v1/consts" ) -func TestPoolMapOperations(t *testing.T) { - pm := make(poolMap) - - poolPath := "token0:token1:500" - params := newPoolParams("token0", "token1", 500, "4295128740") - pool := newPool(params) - - pm.Set(poolPath, pool) - - retrieved, exists := pm.Get(poolPath) - if !exists { - t.Error("Expected pool to exist") - } - if retrieved != pool { - t.Error("Retrieved pool doesn't match original") - } - - _, exists = pm.Get("nonexistent") - if exists { - t.Error("Expected pool to not exist") - } -} - func TestNewPoolParams(t *testing.T) { params := newPoolParams( "token0", @@ -98,7 +75,7 @@ func TestTickSpacingMap(t *testing.T) { } for _, tt := range tests { - spacing := feeAmountTickSpacing.Get(tt.fee) + spacing := GetFeeAmountTickSpacing(tt.fee) if spacing != tt.tickSpacing { t.Errorf("For fee %d, expected tick spacing %d, got %d", tt.fee, tt.tickSpacing, spacing) @@ -176,11 +153,7 @@ func TestCreatePool(t *testing.T) { if !tt.shouldPanic { // verify pool was created correctly poolPath := GetPoolPath(tt.token0Path, tt.token1Path, tt.fee) - pool, exists := pools.Get(poolPath) - if !exists { - t.Errorf("pool was not created") - return - } + pool := mustGetPool(poolPath) // check if GNOT was properly wrapped expectedToken0 := tt.token0Path diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 7a7050321..9cef9f82e 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" u256 "gno.land/p/gnoswap/uint256" @@ -57,7 +58,7 @@ func TestBurn(t *testing.T) { tokensOwed1: u256.NewUint(0), } mockPool := &Pool{ - positions: make(map[string]PositionInfo), + positions: avl.NewTree(), } GetPool = func(token0Path, token1Path string, fee uint32) *Pool { @@ -98,7 +99,7 @@ func TestBurn(t *testing.T) { // setup position for this test posKey := getPositionKey(mockCaller, tt.tickLower, tt.tickUpper) - mockPool.positions[posKey] = mockPosition + mockPool.positions.Set(posKey, mockPosition) if tt.expectPanic { defer func() { @@ -125,7 +126,7 @@ func TestBurn(t *testing.T) { t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) } - newPosition := mockPool.positions[posKey] + newPosition := mockPool.mustGetPosition(posKey) if newPosition.tokensOwed0.IsZero() { t.Error("expected tokensOwed0 to be updated") } diff --git a/pool/position.gno b/pool/position.gno index 8afba9905..572b556dd 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -120,10 +120,10 @@ func (p *Pool) positionUpdateWithKey( feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() - positionToUpdate := p.GetPosition(positionKey) + positionToUpdate, _ := p.GetPosition(positionKey) positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) - p.positions[positionKey] = positionAfterUpdate + p.setPosition(positionKey, positionAfterUpdate) return positionAfterUpdate } @@ -153,23 +153,31 @@ func (p *Pool) PositionTokensOwed1(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed1 } -func (p *Pool) mustGetPosition(key string) PositionInfo { - position, exist := p.positions[key] +// GetPosition returns the position info for a given key. +func (p *Pool) GetPosition(key string) (PositionInfo, bool) { + iPositionInfo, exist := p.positions.Get(key) if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("position(%s) does not exist", key), - )) + newPosition := PositionInfo{} + newPosition.valueOrZero() + return newPosition, false } - return position + + return iPositionInfo.(PositionInfo), true } -func (p *Pool) GetPosition(key string) PositionInfo { - position, exist := p.positions[key] +// setPosition sets the position info for a given key. +func (p *Pool) setPosition(posKey string, positionInfo PositionInfo) { + p.positions.Set(posKey, positionInfo) +} + +// mustGetPosition returns the position info for a given key. +func (p *Pool) mustGetPosition(positionKey string) PositionInfo { + positionInfo, exist := p.GetPosition(positionKey) if !exist { - newPosition := PositionInfo{} - newPosition.valueOrZero() - return newPosition + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("positionKey(%s) does not exist", positionKey), + )) } - return position + return positionInfo } diff --git a/pool/position_update.gno b/pool/position_update.gno index 827d91f10..9e3fe0304 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -88,10 +88,10 @@ func (p *Pool) updatePosition(positionParams ModifyPositionParams, tick int32) P // clear any tick data that is no longer needed if positionParams.liquidityDelta.IsNeg() { if flippedLower { - delete(p.ticks, positionParams.tickLower) + p.deleteTick(positionParams.tickLower) } if flippedUpper { - delete(p.ticks, positionParams.tickUpper) + p.deleteTick(positionParams.tickUpper) } } diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno index ad2e67077..b54ae6d09 100644 --- a/pool/position_update_test.gno +++ b/pool/position_update_test.gno @@ -67,8 +67,8 @@ func TestUpdatePosition(t *testing.T) { } if !tt.positionParams.liquidityDelta.IsZero() { - lowerTick := p.ticks[tt.positionParams.tickLower] - upperTick := p.ticks[tt.positionParams.tickUpper] + lowerTick := p.mustGetTick(tt.positionParams.tickLower) + upperTick := p.mustGetTick(tt.positionParams.tickUpper) if !lowerTick.initialized { t.Error("lower tick not initialized") diff --git a/pool/protocol_fee_withdrawal_test.gno b/pool/protocol_fee_withdrawal_test.gno index 241459d91..7cbac5d39 100644 --- a/pool/protocol_fee_withdrawal_test.gno +++ b/pool/protocol_fee_withdrawal_test.gno @@ -55,8 +55,7 @@ func TestHandleWithdrawalFee(t *testing.T) { InitialisePoolTest(t) std.TestSetRealm(std.NewUserRealm(users.Resolve(position))) poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) - _, found := pools[poolPath] - if !found { + if !pools.Has(poolPath) { panic("pool not found") } TokenApprove(t, wugnotPath, alice, protocolFee, uint64(0)) diff --git a/pool/swap_test.gno b/pool/swap_test.gno index 3b73a5139..25cb7c4f8 100644 --- a/pool/swap_test.gno +++ b/pool/swap_test.gno @@ -4,6 +4,7 @@ import ( "std" "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/uassert" "gno.land/r/demo/users" @@ -201,14 +202,13 @@ func TestComputeSwap(t *testing.T) { }, feeGrowthGlobal0X128: u256.Zero(), feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), } wordPos, _ := tickBitmapPosition(0) - // TODO: use avl - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + mockPool.setTickBitmap(wordPos, u256.NewUint(1)) t.Run("basic swap", func(t *testing.T) { comp := SwapComputation{ @@ -280,7 +280,6 @@ func TestComputeSwap(t *testing.T) { } func TestSwap_Failures(t *testing.T) { - t.Skip() const addr = pusers.AddressOrName(consts.ROUTER_ADDR) tests := []struct { @@ -412,9 +411,9 @@ func TestDrySwap_Failures(t *testing.T) { }, feeGrowthGlobal0X128: u256.Zero(), feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), } originalGetPool := GetPool diff --git a/pool/tick.gno b/pool/tick.gno index 63e070222..7c147d1eb 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -1,6 +1,8 @@ package pool import ( + "strconv" + "gno.land/p/demo/ufmt" i256 "gno.land/p/gnoswap/int256" @@ -282,14 +284,21 @@ func (p *Pool) tickCross( thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) - p.ticks[tick] = thisTick + p.setTick(tick, thisTick) return thisTick.liquidityNet.Clone() } // setTick updates the tick data for the specified tick index in the pool. func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { - p.ticks[tick] = newTickInfo + tickStr := strconv.Itoa(int(tick)) + p.ticks.Set(tickStr, newTickInfo) +} + +// deleteTick deletes the tick data for the specified tick index in the pool. +func (p *Pool) deleteTick(tick int32) { + tickStr := strconv.Itoa(int(tick)) + p.ticks.Remove(tickStr) } // getTick retrieves the TickInfo associated with the specified tick index from the pool. @@ -310,9 +319,15 @@ func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { // This function ensures the retrieved tick data is always valid and safe for further operations, // such as calculations or updates, by sanitizing nil fields in the TickInfo structure. func (p *Pool) getTick(tick int32) TickInfo { - tickInfo := p.ticks[tick] - tickInfo.valueOrZero() - return tickInfo + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) + if !exist { + tickInfo := TickInfo{} + tickInfo.valueOrZero() + return tickInfo + } + + return iTickInfo.(TickInfo) } // GetTickLiquidityGross returns the gross liquidity for the specified tick. @@ -379,7 +394,8 @@ func (p *Pool) GetTickInitialized(tick int32) bool { // tickInfo := pool.mustGetTick(10) // fmt.Println("Tick Info:", tickInfo) func (p *Pool) mustGetTick(tick int32) TickInfo { - tickInfo, exist := p.ticks[tick] + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) if !exist { panic(addDetailToError( errDataNotFound, @@ -387,5 +403,5 @@ func (p *Pool) mustGetTick(tick int32) TickInfo { )) } - return tickInfo + return iTickInfo.(TickInfo) } diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 53c9cc416..dc7369b7f 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -1,6 +1,9 @@ package pool import ( + "strconv" + + "gno.land/p/demo/ufmt" plp "gno.land/p/gnoswap/pool" u256 "gno.land/p/gnoswap/uint256" @@ -79,16 +82,32 @@ func (p *Pool) tickBitmapNextInitializedTickWithInOneWord( // getTickBitmap gets the tick bitmap for the given word position // if the tick bitmap is not initialized, initialize it to zero func (p *Pool) getTickBitmap(wordPos int16) *u256.Uint { - if p.tickBitmaps[wordPos] == nil { - p.tickBitmaps[wordPos] = u256.Zero() + wordPosStr := strconv.Itoa(int(wordPos)) + + if !p.tickBitmaps.Has(wordPosStr) { + p.initTickBitmap(wordPos) + } + + iU256, exist := p.tickBitmaps.Get(wordPosStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("tickBitmap(%d) does not exist", wordPos), + )) } - return p.tickBitmaps[wordPos] + return iU256.(*u256.Uint) } // setTickBitmap sets the tick bitmap for the given word position -func (p *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { - p.tickBitmaps[wordPos] = bitmap +func (p *Pool) setTickBitmap(wordPos int16, tickBitmap *u256.Uint) { + wordPosStr := strconv.Itoa(int(wordPos)) + p.tickBitmaps.Set(wordPosStr, tickBitmap) +} + +// initTickBitmap initializes the tick bitmap for the given word position +func (p *Pool) initTickBitmap(wordPos int16) { + p.setTickBitmap(wordPos, u256.Zero()) } // getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction diff --git a/pool/tick_bitmap_test.gno b/pool/tick_bitmap_test.gno index 37c8c047b..05d52a001 100644 --- a/pool/tick_bitmap_test.gno +++ b/pool/tick_bitmap_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" u256 "gno.land/p/gnoswap/uint256" ) @@ -76,7 +77,7 @@ func TestTickBitmapFlipTick(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.shouldPanic { @@ -92,7 +93,7 @@ func TestTickBitmapFlipTick(t *testing.T) { if !tt.shouldPanic { wordPos, bitPos := tickBitmapPosition(tt.tick / tt.tickSpacing) expected := new(u256.Uint).Lsh(u256.NewUint(1), uint(bitPos)) - if pool.tickBitmaps[wordPos].Cmp(expected) != 0 { + if pool.getTickBitmap(wordPos).Cmp(expected) != 0 { t.Errorf("bitmap not set correctly") } } @@ -137,7 +138,7 @@ func TestTickBitmapNextInitializedTickWithInOneWord(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.setupBitmap != nil { tt.setupBitmap(pool) diff --git a/pool/tick_test.gno b/pool/tick_test.gno index 32fef15e7..d7a2aee05 100644 --- a/pool/tick_test.gno +++ b/pool/tick_test.gno @@ -3,6 +3,7 @@ package pool import ( "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/uassert" i256 "gno.land/p/gnoswap/int256" @@ -60,24 +61,24 @@ func TestcalculateMaxLiquidityPerTick(t *testing.T) { func TestCalculateFeeGrowthInside(t *testing.T) { // Create a mock pool pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup test ticks - pool.ticks[0] = TickInfo{ + pool.ticks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(100), feeGrowthOutside0X128: u256.NewUint(5), feeGrowthOutside1X128: u256.NewUint(7), initialized: true, - } - pool.ticks[100] = TickInfo{ + }) + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(2000), liquidityNet: i256.NewInt(-100), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -261,7 +262,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { func TestTickUpdate(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } tests := []struct { @@ -403,7 +404,7 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + info := pool.mustGetTick(1) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -424,7 +425,7 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + info := pool.mustGetTick(1) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -441,9 +442,9 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[2] - uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "") - uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "") + info := pool.getTick(2) + uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "0") + uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "0") }, }, } @@ -494,17 +495,17 @@ func TestTickUpdate(t *testing.T) { func TestTickCross(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup initial tick state - pool.ticks[100] = TickInfo{ + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(500), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -539,7 +540,7 @@ func TestTickCross(t *testing.T) { func TestGetTick(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup a tick @@ -550,7 +551,7 @@ func TestGetTick(t *testing.T) { feeGrowthOutside1X128: u256.NewUint(15), initialized: true, } - pool.ticks[50] = expectedTick + pool.setTick(50, expectedTick) tests := []struct { name string @@ -710,9 +711,7 @@ func setTick( ) { t.Helper() - info := pool.ticks[tick] - info.valueOrZero() - + info := pool.getTick(tick) info.feeGrowthOutside0X128 = feeGrowthOutside0X128 info.feeGrowthOutside1X128 = feeGrowthOutside1X128 info.liquidityGross = liquidityGross @@ -722,10 +721,10 @@ func setTick( info.secondsOutside = secondsOutside info.initialized = initialized - pool.ticks[tick] = info + pool.setTick(tick, info) } func deleteTick(t *testing.T, pool *Pool, tick int32) { t.Helper() - delete(pool.ticks, tick) + pool.deleteTick(tick) } diff --git a/pool/type.gno b/pool/type.gno index 2cca5dbc9..a94fb96ff 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,6 +3,7 @@ package pool import ( "std" + "gno.land/p/demo/avl" "gno.land/p/demo/ufmt" "gno.land/r/gnoswap/v1/common" @@ -313,12 +314,6 @@ func (t *TickInfo) valueOrZero() { t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() } -type ( - Ticks map[int32]TickInfo // tick => TickInfo - TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) - Positions map[string]PositionInfo // positionKey => PositionInfo -) - // type Pool describes a single Pool's state // A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 type Pool struct { @@ -343,11 +338,11 @@ type Pool struct { liquidity *u256.Uint // total amount of liquidity in the pool - ticks Ticks // maps tick index to tick + ticks *avl.Tree // tick(int32) -> TickInfo - tickBitmaps TickBitmaps // maps tick index to tick bitmap + tickBitmaps *avl.Tree // tick(wordPos)(int16) -> bitMap(tickWord ^ mask)(*u256.Uint) - positions Positions // maps the key (caller, lower tick, upper tick) to a unique position + positions *avl.Tree // maps the key (caller, lower tick, upper tick) to a unique position } func newPool(poolInfo *createPoolParams) *Pool { @@ -367,9 +362,9 @@ func newPool(poolInfo *createPoolParams) *Pool { feeGrowthGlobal1X128: u256.Zero(), protocolFees: newProtocolFees(), liquidity: u256.Zero(), - ticks: Ticks{}, - tickBitmaps: TickBitmaps{}, - positions: Positions{}, + ticks: avl.NewTree(), + tickBitmaps: avl.NewTree(), + positions: avl.NewTree(), } } @@ -442,10 +437,12 @@ func (p *Pool) Liquidity() *u256.Uint { } func mustGetPool(poolPath string) *Pool { - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError(errDataNotFound, - ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) + pool := GetPoolFromPoolPath(poolPath) + if pool == nil { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), + )) } return pool } diff --git a/pool/utils.gno b/pool/utils.gno index c731d3618..9f3017af0 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -82,7 +82,7 @@ func safeConvertToInt128(value *u256.Uint) *i256.Int { // This function validates that the given `value` is properly initialized and checks whether // it exceeds the maximum value of uint128. If the value exceeds the uint128 range, // it applies a masking operation to truncate the value to fit within the uint128 limit. -// +//q // Parameters: // - value: *u256.Uint, the value to be checked and possibly truncated. //