From bf8e09dcf665be1ab1c0ed0e99390cf29f175067 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Tue, 10 Dec 2024 17:26:21 +0900 Subject: [PATCH] feat: use avl.Tree instead of map --- pool/_helper_test.gno | 21 ++++---- pool/api.gno | 43 +++++++++------ pool/api_test.gno | 2 +- pool/getter.gno | 6 ++- pool/getter_test.gno | 16 +++--- pool/pool.gno | 31 +++++++---- pool/pool_manager.gno | 76 ++++++++++----------------- pool/pool_manager_test.gno | 45 ++++------------ pool/pool_test.gno | 18 ++++--- pool/position.gno | 13 +++-- pool/position_update.gno | 6 ++- pool/position_update_test.gno | 27 +++++----- pool/protocol_fee_withdrawal_test.gno | 3 +- pool/tick.gno | 18 +++++-- pool/tick_bitmap.gno | 30 ++++++----- pool/tick_bitmap_test.gno | 11 ++-- pool/tick_test.gno | 55 +++++++++++-------- pool/type.gno | 26 +++++---- 18 files changed, 234 insertions(+), 213 deletions(-) diff --git a/pool/_helper_test.gno b/pool/_helper_test.gno index 992d99cc..99ac4ca7 100644 --- a/pool/_helper_test.gno +++ b/pool/_helper_test.gno @@ -4,6 +4,13 @@ import ( "std" "testing" + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + pusers "gno.land/p/demo/users" + + "gno.land/r/demo/users" + "gno.land/r/demo/wugnot" "gno.land/r/gnoswap/v1/gns" "gno.land/r/onbloc/bar" @@ -13,10 +20,6 @@ import ( "gno.land/r/onbloc/qux" "gno.land/r/onbloc/usdc" - "gno.land/p/demo/testutils" - "gno.land/p/demo/uassert" - pusers "gno.land/p/demo/users" - "gno.land/r/demo/users" "gno.land/r/gnoswap/v1/consts" pn "gno.land/r/gnoswap/v1/position" ) @@ -360,7 +363,7 @@ func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { // resetObject resets the object state(clear or make it default values) func resetObject(t *testing.T) { - pools = make(poolMap) + pools = avl.NewTree() slot0FeeProtocol = 0 poolCreationFee = 100_000_000 withdrawalFeeBPS = 100 @@ -413,11 +416,11 @@ func burnUsdc(addr pusers.AddressOrName) { func TestBeforeResetObject(t *testing.T) { // make some data - pools = make(poolMap) - pools["gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc"] = &Pool{ + pools = avl.NewTree() + pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{ token0Path: "gno.land/r/gnoswap/v1/gns", token1Path: "gno.land/r/onbloc/usdc", - } + }) slot0FeeProtocol = 1 poolCreationFee = 100_000_000 @@ -434,7 +437,7 @@ func TestBeforeResetObject(t *testing.T) { func TestResetObject(t *testing.T) { resetObject(t) - uassert.Equal(t, len(pools), 0) + uassert.Equal(t, pools.Size(), 0) uassert.Equal(t, slot0FeeProtocol, uint8(0)) uassert.Equal(t, poolCreationFee, uint64(100_000_000)) uassert.Equal(t, withdrawalFeeBPS, uint64(100)) diff --git a/pool/api.gno b/pool/api.gno index afe4d9fa..35a36dfa 100644 --- a/pool/api.gno +++ b/pool/api.gno @@ -2,15 +2,15 @@ package pool import ( b64 "encoding/base64" - - "gno.land/p/demo/json" - "std" "strconv" "strings" "time" + "gno.land/p/demo/json" "gno.land/p/demo/ufmt" + + u256 "gno.land/p/gnoswap/uint256" ) type RpcPool struct { @@ -81,10 +81,11 @@ type RpcPosition struct { func ApiGetPools() string { rpcPools := []RpcPool{} - for poolPath, _ := range pools { + pools.Iterate("", "", func(poolPath string, value interface{}) bool { rpcPool := rpcMakePool(poolPath) rpcPools = append(rpcPools, rpcPool) - } + return false + }) responses := json.ArrayNode("", []*json.Node{}) for _, pool := range rpcPools { @@ -122,10 +123,10 @@ func ApiGetPools() string { } func ApiGetPool(poolPath string) string { - _, exist := pools[poolPath] - if !exist { + if getPoolFromPoolPath(poolPath) == nil { return "" } + rpcPool := rpcMakePool(poolPath) responseNode := json.ObjectNode("", map[string]*json.Node{ @@ -198,8 +199,10 @@ func rpcMakePool(poolPath string) RpcPool { rpcPool.Liquidity = pool.liquidity.ToString() rpcPool.Ticks = RpcTicks{} - for tick, tickInfo := range pool.ticks { - rpcPool.Ticks[tick] = RpcTickInfo{ + pool.ticks.Iterate("", "", func(strTick string, value interface{}) bool { + tick, _ := strconv.Atoi(strTick) + tickInfo := value.(TickInfo) + rpcPool.Ticks[int32(tick)] = RpcTickInfo{ LiquidityGross: tickInfo.liquidityGross.ToString(), LiquidityNet: tickInfo.liquidityNet.ToString(), FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(), @@ -209,18 +212,24 @@ func rpcMakePool(poolPath string) RpcPool { SecondsOutside: tickInfo.secondsOutside, Initialized: tickInfo.initialized, } - } + + return false + }) rpcPool.TickBitmaps = RpcTickBitmaps{} - for tick, tickBitmap := range pool.tickBitmaps { - rpcPool.TickBitmaps[tick] = tickBitmap.ToString() - } + pool.tickBitmaps.Iterate("", "", func(strTick string, value interface{}) bool { + tick, _ := strconv.Atoi(strTick) + tickBitmap := value.(*u256.Uint) + rpcPool.TickBitmaps[int16(tick)] = tickBitmap.ToString() + + return false + }) - Positions := pool.positions rpcPositions := []RpcPosition{} - for posKey, posInfo := range Positions { + pool.positions.Iterate("", "", func(posKey string, value interface{}) bool { owner, tickLower, tickUpper := posKeyDivide(posKey) + posInfo := value.(PositionInfo) rpcPositions = append(rpcPositions, RpcPosition{ Owner: owner, TickLower: tickLower, @@ -229,7 +238,9 @@ func rpcMakePool(poolPath string) RpcPool { Token0Owed: posInfo.tokensOwed0.ToString(), Token1Owed: posInfo.tokensOwed1.ToString(), }) - } + + return false + }) rpcPool.Positions = rpcPositions return rpcPool diff --git a/pool/api_test.gno b/pool/api_test.gno index 58031e92..fcef1abe 100644 --- a/pool/api_test.gno +++ b/pool/api_test.gno @@ -21,7 +21,7 @@ func TestInitTwoPools(t *testing.T) { // bar:baz CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000 - uassert.Equal(t, len(pools), 2) + uassert.Equal(t, pools.Size(), 2) } func TestApiGetPools(t *testing.T) { diff --git a/pool/getter.gno b/pool/getter.gno index f962846f..1b4d36b0 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -3,9 +3,11 @@ package pool // pool func PoolGetPoolList() []string { poolPaths := []string{} - for poolPath, _ := range pools { + + pools.Iterate("", "", func(poolPath string, value interface{}) bool { poolPaths = append(poolPaths, poolPath) - } + return false + }) return poolPaths } diff --git a/pool/getter_test.gno b/pool/getter_test.gno index 1772cc22..e0727669 100644 --- a/pool/getter_test.gno +++ b/pool/getter_test.gno @@ -3,6 +3,8 @@ package pool import ( "testing" + "gno.land/p/demo/avl" + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" ) @@ -35,8 +37,8 @@ func TestInitData(t *testing.T) { liquidity: u256.NewUint(1000000), } - mockTicks := Ticks{} - mockTicks[0] = TickInfo{ + mockTicks := avl.NewTree() + mockTicks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000000), liquidityNet: i256.NewInt(2000000), feeGrowthOutside0X128: u256.NewUint(3000000), @@ -45,20 +47,20 @@ func TestInitData(t *testing.T) { secondsPerLiquidityOutsideX128: u256.NewUint(6000000), secondsOutside: 7, initialized: true, - } + }) mockPool.ticks = mockTicks - mockPositions := Positions{} - mockPositions["test_position"] = PositionInfo{ + mockPositions := avl.NewTree() + mockPositions.Set("test_position", PositionInfo{ liquidity: u256.NewUint(1000000), feeGrowthInside0LastX128: u256.NewUint(2000000), feeGrowthInside1LastX128: u256.NewUint(3000000), tokensOwed0: u256.NewUint(4000000), tokensOwed1: u256.NewUint(5000000), - } + }) mockPool.positions = mockPositions - pools["token0:token1:3000"] = mockPool + pools.Set("token0:token1:3000", mockPool) } func TestPoolGetters(t *testing.T) { diff --git a/pool/pool.gno b/pool/pool.gno index cfb655f0..f36bbf58 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -98,7 +98,7 @@ func Burn( } positionKey := positionGetKey(caller, tickLower, tickUpper) - pool.positions[positionKey] = position + pool.positions.Set(positionKey, position) // actual token transfer happens in Collect() return amount0.ToString(), amount1.ToString() @@ -132,13 +132,14 @@ func Collect( pool := GetPool(token0Path, token1Path, fee) positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper) - position, exist := pool.positions[positionKey] + iposition, exist := pool.positions.Get(positionKey) if !exist { panic(addDetailToError( errDataNotFound, ufmt.Sprintf("pool.gno__Collect() || positionKey(%s) does not exist", positionKey), )) } + position := iposition.(PositionInfo) var amount0, amount1 *u256.Uint @@ -158,7 +159,7 @@ func Collect( token1 := common.GetTokenTeller(pool.token1Path) checkTransferError(token1.Transfer(recipient, amount1.Uint64())) - pool.positions[positionKey] = position + pool.positions.Set(positionKey, position) return amount0.ToString(), amount1.ToString() } @@ -697,9 +698,11 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) // iterate all pool - for _, pool := range pools { + pools.Iterate("", "", func(poolPath string, value interface{}) bool { + pool := value.(*Pool) pool.slot0.feeProtocol = newFee - } + return false + }) // update slot0 slot0FeeProtocol = newFee @@ -1054,10 +1057,20 @@ func (p *Pool) GetLiquidity() *u256.Uint { } func mustGetPool(poolPath string) *Pool { - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError(errDataNotFound, - ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) + pool := getPoolFromPoolPath(poolPath) + if pool == nil { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("pool.gno__mustGetPool() || expected poolPath(%s) to exist", poolPath), + )) } return pool } + +func getPoolFromPoolPath(poolPath string) *Pool { + pool, exist := pools.Get(poolPath) + if !exist { + return nil + } + return pool.(*Pool) +} diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index c8f5c25b..68cc0026 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -2,53 +2,33 @@ package pool import ( "std" + "strconv" "strings" + "gno.land/p/demo/avl" "gno.land/p/demo/ufmt" - "gno.land/r/gnoswap/v1/common" - "gno.land/r/gnoswap/v1/consts" - - en "gno.land/r/gnoswap/v1/emission" + u256 "gno.land/p/gnoswap/uint256" "gno.land/r/gnoswap/v1/gns" - u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" + en "gno.land/r/gnoswap/v1/emission" ) -type poolMap map[string]*Pool - -func (pm *poolMap) Get(poolPath string) (*Pool, bool) { - pool, exist := (*pm)[poolPath] - if !exist { - return nil, false - } - - return pool, true -} - -func (pm *poolMap) Set(poolPath string, pool *Pool) { - (*pm)[poolPath] = pool -} - -type tickSpacingMap map[uint32]int32 - -func (t *tickSpacingMap) Get(fee uint32) int32 { - return (*t)[fee] -} - var ( - feeAmountTickSpacing tickSpacingMap = make(tickSpacingMap) // maps fee to tickSpacing || map[feeAmount]tick_spacing - pools poolMap = make(poolMap) // maps poolPath to pool || map[poolPath]*Pool + feeAmountTickSpacing = avl.NewTree() // feeBps(uint32) -> tickSpacing(int32) + pools = avl.NewTree() // poolPath -> *Pool slot0FeeProtocol uint8 = 0 ) func init() { - feeAmountTickSpacing[100] = 1 // 0.01% - feeAmountTickSpacing[500] = 10 // 0.05% - feeAmountTickSpacing[3000] = 60 // 0.3% - feeAmountTickSpacing[10000] = 200 // 1% + feeAmountTickSpacing.Set("100", int32(1)) // 0.01% + feeAmountTickSpacing.Set("500", int32(10)) // 0.05% + feeAmountTickSpacing.Set("3000", int32(60)) // 0.3% + feeAmountTickSpacing.Set("10000", int32(200)) // 1% } // createPoolParams holds the essential parameters for creating a new pool. @@ -67,13 +47,20 @@ func newPoolParams( sqrtPriceX96 string, ) *createPoolParams { price := u256.MustFromDecimal(sqrtPriceX96) - tickSpacing := feeAmountTickSpacing.Get(fee) + tickSpacing, exist := feeAmountTickSpacing.Get(strconv.FormatUint(uint64(fee), 10)) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("pool_manager.gno__newPoolParams() || expected tickSpacing to exist for fee(%d)", fee), + )) + } + return &createPoolParams{ token0Path: token0Path, token1Path: token1Path, fee: fee, sqrtPriceX96: price, - tickSpacing: tickSpacing, + tickSpacing: tickSpacing.(int32), } } @@ -216,8 +203,11 @@ func CreatePool( // DoesPoolPathExist checks if a pool exists for the given poolPath. // The poolPath is a unique identifier for a pool, combining token paths and fee. func DoesPoolPathExist(poolPath string) bool { - _, exist := pools[poolPath] - return exist + if _, exist := pools.Get(poolPath); exist { + return true + } + + return false } // GetPool retrieves the pool for the given token paths and fee. @@ -225,20 +215,12 @@ func DoesPoolPathExist(poolPath string) bool { // Returns pool struct func GetPool(token0Path, token1Path string, fee uint32) *Pool { poolPath := GetPoolPath(token0Path, token1Path, fee) - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPool() || expected poolPath(%s) to exist", poolPath), - )) - } - - return pool + return GetPoolFromPoolPath(poolPath) } // GetPoolFromPoolPath retrieves the pool for the given poolPath. func GetPoolFromPoolPath(poolPath string) *Pool { - pool, exist := pools[poolPath] + pool, exist := pools.Get(poolPath) if !exist { panic(addDetailToError( errDataNotFound, @@ -246,7 +228,7 @@ func GetPoolFromPoolPath(poolPath string) *Pool { )) } - return pool + return pool.(*Pool) } // GetPoolPath generates a poolPath from the given token paths and fee. diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 328315b5..d057765a 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -8,29 +8,6 @@ import ( "gno.land/r/gnoswap/v1/consts" ) -func TestPoolMapOperations(t *testing.T) { - pm := make(poolMap) - - poolPath := "token0:token1:500" - params := newPoolParams("token0", "token1", 500, "4295128740") - pool := newPool(params) - - pm.Set(poolPath, pool) - - retrieved, exists := pm.Get(poolPath) - if !exists { - t.Error("Expected pool to exist") - } - if retrieved != pool { - t.Error("Retrieved pool doesn't match original") - } - - _, exists = pm.Get("nonexistent") - if exists { - t.Error("Expected pool to not exist") - } -} - func TestNewPoolParams(t *testing.T) { params := newPoolParams( "token0", @@ -88,19 +65,19 @@ func TestGetPoolPath(t *testing.T) { func TestTickSpacingMap(t *testing.T) { tests := []struct { - fee uint32 + fee string tickSpacing int32 }{ - {100, 1}, // 0.01% - {500, 10}, // 0.05% - {3000, 60}, // 0.3% - {10000, 200}, // 1% + {"100", 1}, // 0.01% + {"500", 10}, // 0.05% + {"3000", 60}, // 0.3% + {"10000", 200}, // 1% } for _, tt := range tests { - spacing := feeAmountTickSpacing.Get(tt.fee) - if spacing != tt.tickSpacing { - t.Errorf("For fee %d, expected tick spacing %d, got %d", + spacing, _ := feeAmountTickSpacing.Get(tt.fee) + if spacing.(int32) != tt.tickSpacing { + t.Errorf("For fee %s, expected tick spacing %d, got %d", tt.fee, tt.tickSpacing, spacing) } } @@ -176,11 +153,7 @@ func TestCreatePool(t *testing.T) { if !tt.shouldPanic { // verify pool was created correctly poolPath := GetPoolPath(tt.token0Path, tt.token1Path, tt.fee) - pool, exists := pools.Get(poolPath) - if !exists { - t.Errorf("pool was not created") - return - } + pool := GetPoolFromPoolPath(poolPath) // check if GNOT was properly wrapped expectedToken0 := tt.token0Path diff --git a/pool/pool_test.gno b/pool/pool_test.gno index c6ef82d4..c4ea6c93 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -2,8 +2,10 @@ package pool import ( "std" + "strconv" "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" @@ -60,7 +62,7 @@ func TestBurn(t *testing.T) { tokensOwed1: u256.NewUint(0), } mockPool := &Pool{ - positions: make(map[string]PositionInfo), + positions: avl.NewTree(), } GetPool = func(token0Path, token1Path string, fee uint32) *Pool { @@ -101,7 +103,7 @@ func TestBurn(t *testing.T) { // setup position for this test posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) - mockPool.positions[posKey] = mockPosition + mockPool.positions.Set(posKey, mockPosition) if tt.expectPanic { defer func() { @@ -128,7 +130,8 @@ func TestBurn(t *testing.T) { t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) } - newPosition := mockPool.positions[posKey] + iNewPosition, _ := mockPool.positions.Get(posKey) + newPosition := iNewPosition.(PositionInfo) if newPosition.tokensOwed0.IsZero() { t.Error("expected tokensOwed0 to be updated") } @@ -475,14 +478,13 @@ func TestComputeSwap(t *testing.T) { }, feeGrowthGlobal0X128: u256.Zero(), feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), } wordPos, _ := tickBitmapPosition(0) - // TODO: use avl - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + mockPool.tickBitmaps.Set(strconv.FormatInt(int64(wordPos), 10), u256.NewUint(1)) t.Run("basic swap", func(t *testing.T) { comp := SwapComputation{ diff --git a/pool/position.gno b/pool/position.gno index 64d37e9d..98237bfd 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -53,9 +53,14 @@ func (pool *Pool) positionUpdateWithKey( feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() - positionToUpdate := pool.positions[positionKey] + positionToUpdate := PositionInfo{} + ipositionToUpdate, exist := pool.positions.Get(positionKey) + if exist { + positionToUpdate = ipositionToUpdate.(PositionInfo) + } + positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) - pool.positions[positionKey] = positionAfterUpdate + pool.positions.Set(positionKey, positionAfterUpdate) return positionAfterUpdate } @@ -128,7 +133,7 @@ func (p *Pool) GetPositionTokensOwed1(key string) *u256.Uint { } func (p *Pool) mustGetPosition(key string) PositionInfo { - position, exist := p.positions[key] + iposition, exist := p.positions.Get(key) if !exist { panic(addDetailToError( errDataNotFound, @@ -136,5 +141,5 @@ func (p *Pool) mustGetPosition(key string) PositionInfo { )) } - return position + return iposition.(PositionInfo) } diff --git a/pool/position_update.gno b/pool/position_update.gno index 5f21e49f..5ba16994 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -1,6 +1,8 @@ package pool import ( + "strconv" + u256 "gno.land/p/gnoswap/uint256" ) @@ -58,11 +60,11 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn if positionParams.liquidityDelta.IsNeg() { if flippedLower { - delete(pool.ticks, positionParams.tickLower) + pool.ticks.Remove(strconv.FormatInt(int64(positionParams.tickLower), 10)) } if flippedUpper { - delete(pool.ticks, positionParams.tickUpper) + pool.ticks.Remove(strconv.FormatInt(int64(positionParams.tickUpper), 10)) } } diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno index 24ca4cb5..2b8826ba 100644 --- a/pool/position_update_test.gno +++ b/pool/position_update_test.gno @@ -1,21 +1,18 @@ package pool import ( + "strconv" "testing" - "std" - - "gno.land/p/demo/uassert" - - "gno.land/r/gnoswap/v1/consts" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/consts" ) func TestUpdatePosition(t *testing.T) { poolParams := &createPoolParams{ - token0Path: "token0", - token1Path: "token1", + token0Path: "token0", + token1Path: "token1", fee: 500, tickSpacing: 10, sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 @@ -23,8 +20,8 @@ func TestUpdatePosition(t *testing.T) { p := newPool(poolParams) tests := []struct { - name string - positionParams ModifyPositionParams + name string + positionParams ModifyPositionParams expectLiquidity *u256.Uint }{ { @@ -62,16 +59,18 @@ func TestUpdatePosition(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { position := p.updatePosition(tt.positionParams) - + if !position.liquidity.Eq(tt.expectLiquidity) { - t.Errorf("liquidity mismatch: expected %s, got %s", - tt.expectLiquidity.ToString(), + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), position.liquidity.ToString()) } if !tt.positionParams.liquidityDelta.IsZero() { - lowerTick := p.ticks[tt.positionParams.tickLower] - upperTick := p.ticks[tt.positionParams.tickUpper] + iLowerTick, _ := p.ticks.Get(strconv.FormatInt(int64(tt.positionParams.tickLower), 10)) + iUpperTick, _ := p.ticks.Get(strconv.FormatInt(int64(tt.positionParams.tickUpper), 10)) + lowerTick := iLowerTick.(TickInfo) + upperTick := iUpperTick.(TickInfo) if !lowerTick.initialized { t.Error("lower tick not initialized") diff --git a/pool/protocol_fee_withdrawal_test.gno b/pool/protocol_fee_withdrawal_test.gno index 241459d9..7c28e863 100644 --- a/pool/protocol_fee_withdrawal_test.gno +++ b/pool/protocol_fee_withdrawal_test.gno @@ -55,8 +55,7 @@ func TestHandleWithdrawalFee(t *testing.T) { InitialisePoolTest(t) std.TestSetRealm(std.NewUserRealm(users.Resolve(position))) poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) - _, found := pools[poolPath] - if !found { + if !DoesPoolPathExist(poolPath) { panic("pool not found") } TokenApprove(t, wugnotPath, alice, protocolFee, uint64(0)) diff --git a/pool/tick.gno b/pool/tick.gno index d325826f..5ed2d17a 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -1,6 +1,8 @@ package pool import ( + "strconv" + "gno.land/p/demo/ufmt" i256 "gno.land/p/gnoswap/int256" @@ -85,7 +87,7 @@ func (pool *Pool) tickUpdate( thisTick.liquidityNet = i256.Zero().Add(thisTick.liquidityNet, liquidityDelta) } - pool.ticks[tick] = thisTick + pool.ticks.Set(strconv.FormatInt(int64(tick), 10), thisTick) return flipped } @@ -101,13 +103,19 @@ func (pool *Pool) tickCross( thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) - pool.ticks[tick] = thisTick + pool.ticks.Set(strconv.FormatInt(int64(tick), 10), thisTick) return thisTick.liquidityNet } func (pool *Pool) getTick(tick int32) TickInfo { - tickInfo := pool.ticks[tick] + iTickInfo, exist := pool.ticks.Get(strconv.FormatInt(int64(tick), 10)) + if !exist { + res := TickInfo{} + res.init() + return res + } + tickInfo := iTickInfo.(TickInfo) tickInfo.init() return tickInfo @@ -177,7 +185,7 @@ func (p *Pool) GetTickInitialized(tick int32) bool { } func (p *Pool) mustGetTick(tick int32) TickInfo { - tickInfo, exist := p.ticks[tick] + tickInfo, exist := p.ticks.Get(strconv.FormatInt(int64(tick), 10)) if !exist { panic(addDetailToError( errDataNotFound, @@ -185,5 +193,5 @@ func (p *Pool) mustGetTick(tick int32) TickInfo { )) } - return tickInfo + return tickInfo.(TickInfo) } diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 45b4ba9e..5aa8efaa 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -1,6 +1,8 @@ package pool import ( + "strconv" + "gno.land/p/demo/ufmt" plp "gno.land/p/gnoswap/pool" @@ -59,16 +61,20 @@ func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( // getTickBitmap gets the tick bitmap for the given word position // if the tick bitmap is not initialized, initialize it to zero func (pool *Pool) getTickBitmap(wordPos int16) *u256.Uint { - if pool.tickBitmaps[wordPos] == nil { - pool.tickBitmaps[wordPos] = u256.Zero() + wordPosStr := strconv.FormatInt(int64(wordPos), 10) + + bitmap, exist := pool.tickBitmaps.Get(wordPosStr) + if !exist { + pool.tickBitmaps.Set(wordPosStr, u256.Zero()) + return u256.Zero() } - return pool.tickBitmaps[wordPos] + return bitmap.(*u256.Uint) } // setTickBitmap sets the tick bitmap for the given word position func (pool *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { - pool.tickBitmaps[wordPos] = bitmap + pool.tickBitmaps.Set(strconv.FormatInt(int64(wordPos), 10), bitmap) } // getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction @@ -81,20 +87,18 @@ func getWordAndBitPos(tick int32, lte bool) (int16, uint8) { return tickBitmapPosition(tick) } -// bMap is a map that maps boolean values to uint values. -// true maps to 1, and false maps to 0. -var bMap = map[bool]uint{ - true: 1, - false: 0, -} - // getMaskBit generates a mask based on the provided bit position (bitPos) and a boolean flag (lte). // The function constructs a bitmask with a shift depending on the bit position and the boolean value. // It either returns the mask or its negation, based on the value of 'lte' (swap direction). func getMaskBit(bitPos uint, lte bool) *u256.Uint { - // Shift the number 1 to the left by (bitPos + bMap[lte]) positions. + // Shift the number 1 to the left by (bitPos + bitPosAdjust) positions. // If lte is true, the shift will be bitPos + 1; if false, it will be just bitPos. - shifted := new(u256.Uint).Lsh(u256.One(), bitPos+bMap[lte]) + var bitPosAdjust uint + if lte { + bitPosAdjust = 1 + } + + shifted := new(u256.Uint).Lsh(u256.One(), bitPos+bitPosAdjust) // Subtract 1 from the shifted value to create a mask. mask := new(u256.Uint).Sub(shifted, u256.One()) diff --git a/pool/tick_bitmap_test.gno b/pool/tick_bitmap_test.gno index 37c8c047..487d9700 100644 --- a/pool/tick_bitmap_test.gno +++ b/pool/tick_bitmap_test.gno @@ -1,8 +1,11 @@ package pool import ( + "strconv" "testing" + "gno.land/p/demo/avl" + u256 "gno.land/p/gnoswap/uint256" ) @@ -76,7 +79,7 @@ func TestTickBitmapFlipTick(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.shouldPanic { @@ -92,7 +95,9 @@ func TestTickBitmapFlipTick(t *testing.T) { if !tt.shouldPanic { wordPos, bitPos := tickBitmapPosition(tt.tick / tt.tickSpacing) expected := new(u256.Uint).Lsh(u256.NewUint(1), uint(bitPos)) - if pool.tickBitmaps[wordPos].Cmp(expected) != 0 { + iBitmap, _ := pool.tickBitmaps.Get(strconv.FormatInt(int64(wordPos), 10)) + bitmap := iBitmap.(*u256.Uint) + if bitmap.Cmp(expected) != 0 { t.Errorf("bitmap not set correctly") } } @@ -137,7 +142,7 @@ func TestTickBitmapNextInitializedTickWithInOneWord(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pool := &Pool{ - tickBitmaps: make(map[int16]*u256.Uint), + tickBitmaps: avl.NewTree(), } if tt.setupBitmap != nil { tt.setupBitmap(pool) diff --git a/pool/tick_test.gno b/pool/tick_test.gno index 34fb95ab..49fd173c 100644 --- a/pool/tick_test.gno +++ b/pool/tick_test.gno @@ -1,8 +1,10 @@ package pool import ( + "strconv" "testing" + "gno.land/p/demo/avl" "gno.land/p/demo/uassert" i256 "gno.land/p/gnoswap/int256" @@ -60,24 +62,24 @@ func TestcalculateMaxLiquidityPerTick(t *testing.T) { func TestCalculateFeeGrowthInside(t *testing.T) { // Create a mock pool pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup test ticks - pool.ticks[0] = TickInfo{ + pool.ticks.Set("0", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(100), feeGrowthOutside0X128: u256.NewUint(5), feeGrowthOutside1X128: u256.NewUint(7), initialized: true, - } - pool.ticks[100] = TickInfo{ + }) + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(2000), liquidityNet: i256.NewInt(-100), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -256,7 +258,7 @@ func TestCalculateFeeGrowthInside(t *testing.T) { func TestTickUpdate(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } tests := []struct { @@ -398,7 +400,8 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + iInfo, _ := pool.ticks.Get(strconv.FormatInt(int64(1), 10)) + info := iInfo.(TickInfo) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -419,7 +422,8 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[1] + iInfo, _ := pool.ticks.Get(strconv.FormatInt(int64(1), 10)) + info := iInfo.(TickInfo) uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") }, @@ -436,9 +440,13 @@ func TestTickUpdate(t *testing.T) { wantFlipped: false, shouldPanic: false, verify: func() { - info := pool.ticks[2] - uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "") - uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "") + var tick TickInfo + iTick, exist := pool.ticks.Get(strconv.FormatInt(int64(2), 10)) + if exist { + tick = iTick.(TickInfo) + } + uassert.Equal(t, tick.feeGrowthOutside0X128.ToString(), "") + uassert.Equal(t, tick.feeGrowthOutside1X128.ToString(), "") }, }, } @@ -489,17 +497,17 @@ func TestTickUpdate(t *testing.T) { func TestTickCross(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup initial tick state - pool.ticks[100] = TickInfo{ + pool.ticks.Set("100", TickInfo{ liquidityGross: u256.NewUint(1000), liquidityNet: i256.NewInt(500), feeGrowthOutside0X128: u256.NewUint(10), feeGrowthOutside1X128: u256.NewUint(15), initialized: true, - } + }) tests := []struct { name string @@ -534,7 +542,7 @@ func TestTickCross(t *testing.T) { func TestGetTick(t *testing.T) { pool := &Pool{ - ticks: make(map[int32]TickInfo), + ticks: avl.NewTree(), } // Setup a tick @@ -545,7 +553,7 @@ func TestGetTick(t *testing.T) { feeGrowthOutside1X128: u256.NewUint(15), initialized: true, } - pool.ticks[50] = expectedTick + pool.ticks.Set("50", expectedTick) tests := []struct { name string @@ -588,8 +596,14 @@ func (pool *Pool) setTick( ) { t.Helper() - info := pool.ticks[tick] - info.init() + var info TickInfo + iInfo, exist := pool.ticks.Get(strconv.FormatInt(int64(tick), 10)) + if !exist { + info = TickInfo{} + info.init() + } else { + info = iInfo.(TickInfo) + } info.feeGrowthOutside0X128 = feeGrowthOutside0X128 info.feeGrowthOutside1X128 = feeGrowthOutside1X128 @@ -600,11 +614,10 @@ func (pool *Pool) setTick( info.secondsOutside = secondsOutside info.initialized = initialized - pool.ticks[tick] = info + pool.ticks.Set(strconv.FormatInt(int64(tick), 10), info) } func (pool *Pool) deleteTick(t *testing.T, tick int32) { t.Helper() - - delete(pool.ticks, tick) + pool.ticks.Remove(strconv.FormatInt(int64(tick), 10)) } diff --git a/pool/type.gno b/pool/type.gno index d2a6c519..b3258a30 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,11 +3,13 @@ package pool import ( "std" - "gno.land/r/gnoswap/v1/common" - "gno.land/r/gnoswap/v1/consts" + "gno.land/p/demo/avl" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) type Slot0 struct { @@ -154,14 +156,14 @@ func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForO step.sqrtPriceStartX96 = state.sqrtPriceX96 step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( state.tick, - pool.tickSpacing, + pool.tickSpacing, zeroForOne, ) // prevent overshoot the min/max tick step.clampTickNext() - // get the price for the next tick + // get the price for the next tick step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) } @@ -225,10 +227,6 @@ func (t *TickInfo) init() { t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() } -type Ticks map[int32]TickInfo // tick => TickInfo -type TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) -type Positions map[string]PositionInfo // positionKey => PositionInfo - // type Pool describes a single Pool's state // A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 type Pool struct { @@ -253,11 +251,11 @@ type Pool struct { liquidity *u256.Uint // total amount of liquidity in the pool - ticks Ticks // maps tick index to tick + ticks *avl.Tree // tick(int32) -> TickInfo - tickBitmaps TickBitmaps // maps tick index to tick bitmap + tickBitmaps *avl.Tree // tick(wordPos)(int16) -> bitMap(tickWord ^ mask)(*u256.Uint) - positions Positions // maps the key (caller, lower tick, upper tick) to a unique position + positions *avl.Tree // positionKey -> PositionInfo } func newPool(poolInfo *createPoolParams) *Pool { @@ -277,8 +275,8 @@ func newPool(poolInfo *createPoolParams) *Pool { feeGrowthGlobal1X128: u256.Zero(), protocolFees: newProtocolFees(), liquidity: u256.Zero(), - ticks: Ticks{}, - tickBitmaps: TickBitmaps{}, - positions: Positions{}, + ticks: avl.NewTree(), + tickBitmaps: avl.NewTree(), + positions: avl.NewTree(), } }