Skip to content

Commit

Permalink
feat: use avl.Tree instead of map
Browse files Browse the repository at this point in the history
  • Loading branch information
r3v4s committed Dec 16, 2024
1 parent 5aadc0c commit bf8e09d
Show file tree
Hide file tree
Showing 18 changed files with 234 additions and 213 deletions.
21 changes: 12 additions & 9 deletions pool/_helper_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ import (
"std"
"testing"

"gno.land/p/demo/avl"
"gno.land/p/demo/testutils"
"gno.land/p/demo/uassert"
pusers "gno.land/p/demo/users"

"gno.land/r/demo/users"

"gno.land/r/demo/wugnot"
"gno.land/r/gnoswap/v1/gns"
"gno.land/r/onbloc/bar"
Expand All @@ -13,10 +20,6 @@ import (
"gno.land/r/onbloc/qux"
"gno.land/r/onbloc/usdc"

"gno.land/p/demo/testutils"
"gno.land/p/demo/uassert"
pusers "gno.land/p/demo/users"
"gno.land/r/demo/users"
"gno.land/r/gnoswap/v1/consts"
pn "gno.land/r/gnoswap/v1/position"
)
Expand Down Expand Up @@ -360,7 +363,7 @@ func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) {

// resetObject resets the object state(clear or make it default values)
func resetObject(t *testing.T) {
pools = make(poolMap)
pools = avl.NewTree()
slot0FeeProtocol = 0
poolCreationFee = 100_000_000
withdrawalFeeBPS = 100
Expand Down Expand Up @@ -413,11 +416,11 @@ func burnUsdc(addr pusers.AddressOrName) {

func TestBeforeResetObject(t *testing.T) {
// make some data
pools = make(poolMap)
pools["gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc"] = &Pool{
pools = avl.NewTree()
pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{
token0Path: "gno.land/r/gnoswap/v1/gns",
token1Path: "gno.land/r/onbloc/usdc",
}
})

slot0FeeProtocol = 1
poolCreationFee = 100_000_000
Expand All @@ -434,7 +437,7 @@ func TestBeforeResetObject(t *testing.T) {

func TestResetObject(t *testing.T) {
resetObject(t)
uassert.Equal(t, len(pools), 0)
uassert.Equal(t, pools.Size(), 0)
uassert.Equal(t, slot0FeeProtocol, uint8(0))
uassert.Equal(t, poolCreationFee, uint64(100_000_000))
uassert.Equal(t, withdrawalFeeBPS, uint64(100))
Expand Down
43 changes: 27 additions & 16 deletions pool/api.gno
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package pool

import (
b64 "encoding/base64"

"gno.land/p/demo/json"

"std"
"strconv"
"strings"
"time"

"gno.land/p/demo/json"
"gno.land/p/demo/ufmt"

u256 "gno.land/p/gnoswap/uint256"
)

type RpcPool struct {
Expand Down Expand Up @@ -81,10 +81,11 @@ type RpcPosition struct {

func ApiGetPools() string {
rpcPools := []RpcPool{}
for poolPath, _ := range pools {
pools.Iterate("", "", func(poolPath string, value interface{}) bool {
rpcPool := rpcMakePool(poolPath)
rpcPools = append(rpcPools, rpcPool)
}
return false
})

responses := json.ArrayNode("", []*json.Node{})
for _, pool := range rpcPools {
Expand Down Expand Up @@ -122,10 +123,10 @@ func ApiGetPools() string {
}

func ApiGetPool(poolPath string) string {
_, exist := pools[poolPath]
if !exist {
if getPoolFromPoolPath(poolPath) == nil {
return ""
}

rpcPool := rpcMakePool(poolPath)

responseNode := json.ObjectNode("", map[string]*json.Node{
Expand Down Expand Up @@ -198,8 +199,10 @@ func rpcMakePool(poolPath string) RpcPool {
rpcPool.Liquidity = pool.liquidity.ToString()

rpcPool.Ticks = RpcTicks{}
for tick, tickInfo := range pool.ticks {
rpcPool.Ticks[tick] = RpcTickInfo{
pool.ticks.Iterate("", "", func(strTick string, value interface{}) bool {
tick, _ := strconv.Atoi(strTick)
tickInfo := value.(TickInfo)
rpcPool.Ticks[int32(tick)] = RpcTickInfo{
LiquidityGross: tickInfo.liquidityGross.ToString(),
LiquidityNet: tickInfo.liquidityNet.ToString(),
FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(),
Expand All @@ -209,18 +212,24 @@ func rpcMakePool(poolPath string) RpcPool {
SecondsOutside: tickInfo.secondsOutside,
Initialized: tickInfo.initialized,
}
}

return false
})

rpcPool.TickBitmaps = RpcTickBitmaps{}
for tick, tickBitmap := range pool.tickBitmaps {
rpcPool.TickBitmaps[tick] = tickBitmap.ToString()
}
pool.tickBitmaps.Iterate("", "", func(strTick string, value interface{}) bool {
tick, _ := strconv.Atoi(strTick)
tickBitmap := value.(*u256.Uint)
rpcPool.TickBitmaps[int16(tick)] = tickBitmap.ToString()

return false
})

Positions := pool.positions
rpcPositions := []RpcPosition{}
for posKey, posInfo := range Positions {
pool.positions.Iterate("", "", func(posKey string, value interface{}) bool {
owner, tickLower, tickUpper := posKeyDivide(posKey)

posInfo := value.(PositionInfo)
rpcPositions = append(rpcPositions, RpcPosition{
Owner: owner,
TickLower: tickLower,
Expand All @@ -229,7 +238,9 @@ func rpcMakePool(poolPath string) RpcPool {
Token0Owed: posInfo.tokensOwed0.ToString(),
Token1Owed: posInfo.tokensOwed1.ToString(),
})
}

return false
})
rpcPool.Positions = rpcPositions

return rpcPool
Expand Down
2 changes: 1 addition & 1 deletion pool/api_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestInitTwoPools(t *testing.T) {
// bar:baz
CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000

uassert.Equal(t, len(pools), 2)
uassert.Equal(t, pools.Size(), 2)
}

func TestApiGetPools(t *testing.T) {
Expand Down
6 changes: 4 additions & 2 deletions pool/getter.gno
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package pool
// pool
func PoolGetPoolList() []string {
poolPaths := []string{}
for poolPath, _ := range pools {

pools.Iterate("", "", func(poolPath string, value interface{}) bool {
poolPaths = append(poolPaths, poolPath)
}
return false
})

return poolPaths
}
Expand Down
16 changes: 9 additions & 7 deletions pool/getter_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package pool
import (
"testing"

"gno.land/p/demo/avl"

i256 "gno.land/p/gnoswap/int256"
u256 "gno.land/p/gnoswap/uint256"
)
Expand Down Expand Up @@ -35,8 +37,8 @@ func TestInitData(t *testing.T) {
liquidity: u256.NewUint(1000000),
}

mockTicks := Ticks{}
mockTicks[0] = TickInfo{
mockTicks := avl.NewTree()
mockTicks.Set("0", TickInfo{
liquidityGross: u256.NewUint(1000000),
liquidityNet: i256.NewInt(2000000),
feeGrowthOutside0X128: u256.NewUint(3000000),
Expand All @@ -45,20 +47,20 @@ func TestInitData(t *testing.T) {
secondsPerLiquidityOutsideX128: u256.NewUint(6000000),
secondsOutside: 7,
initialized: true,
}
})
mockPool.ticks = mockTicks

mockPositions := Positions{}
mockPositions["test_position"] = PositionInfo{
mockPositions := avl.NewTree()
mockPositions.Set("test_position", PositionInfo{
liquidity: u256.NewUint(1000000),
feeGrowthInside0LastX128: u256.NewUint(2000000),
feeGrowthInside1LastX128: u256.NewUint(3000000),
tokensOwed0: u256.NewUint(4000000),
tokensOwed1: u256.NewUint(5000000),
}
})
mockPool.positions = mockPositions

pools["token0:token1:3000"] = mockPool
pools.Set("token0:token1:3000", mockPool)
}

func TestPoolGetters(t *testing.T) {
Expand Down
31 changes: 22 additions & 9 deletions pool/pool.gno
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Burn(
}

positionKey := positionGetKey(caller, tickLower, tickUpper)
pool.positions[positionKey] = position
pool.positions.Set(positionKey, position)

// actual token transfer happens in Collect()
return amount0.ToString(), amount1.ToString()
Expand Down Expand Up @@ -132,13 +132,14 @@ func Collect(
pool := GetPool(token0Path, token1Path, fee)

positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper)
position, exist := pool.positions[positionKey]
iposition, exist := pool.positions.Get(positionKey)
if !exist {
panic(addDetailToError(
errDataNotFound,
ufmt.Sprintf("pool.gno__Collect() || positionKey(%s) does not exist", positionKey),
))
}
position := iposition.(PositionInfo)

var amount0, amount1 *u256.Uint

Expand All @@ -158,7 +159,7 @@ func Collect(
token1 := common.GetTokenTeller(pool.token1Path)
checkTransferError(token1.Transfer(recipient, amount1.Uint64()))

pool.positions[positionKey] = position
pool.positions.Set(positionKey, position)

return amount0.ToString(), amount1.ToString()
}
Expand Down Expand Up @@ -697,9 +698,11 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 {
newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 )

// iterate all pool
for _, pool := range pools {
pools.Iterate("", "", func(poolPath string, value interface{}) bool {
pool := value.(*Pool)
pool.slot0.feeProtocol = newFee
}
return false
})

// update slot0
slot0FeeProtocol = newFee
Expand Down Expand Up @@ -1054,10 +1057,20 @@ func (p *Pool) GetLiquidity() *u256.Uint {
}

func mustGetPool(poolPath string) *Pool {
pool, exist := pools[poolPath]
if !exist {
panic(addDetailToError(errDataNotFound,
ufmt.Sprintf("poolPath(%s) does not exist", poolPath)))
pool := getPoolFromPoolPath(poolPath)
if pool == nil {
panic(addDetailToError(
errDataNotFound,
ufmt.Sprintf("pool.gno__mustGetPool() || expected poolPath(%s) to exist", poolPath),
))
}
return pool
}

func getPoolFromPoolPath(poolPath string) *Pool {
pool, exist := pools.Get(poolPath)
if !exist {
return nil
}
return pool.(*Pool)
}
Loading

0 comments on commit bf8e09d

Please sign in to comment.