Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSW-1838 refactor: use avl.Tree in pool #430

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions pool/_helper_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"gno.land/r/onbloc/qux"
"gno.land/r/onbloc/usdc"

"gno.land/p/demo/avl"
"gno.land/p/demo/testutils"
"gno.land/p/demo/uassert"
pusers "gno.land/p/demo/users"
Expand All @@ -23,7 +24,7 @@ import (

const (
ugnotDenom string = "ugnot"
ugnotPath string = "gno.land/r/gnoswap/v1/pool:ugnot"
ugnotPath string = "ugnot"
wugnotPath string = "gno.land/r/demo/wugnot"
gnsPath string = "gno.land/r/gnoswap/v1/gns"
barPath string = "gno.land/r/onbloc/bar"
Expand Down Expand Up @@ -517,7 +518,7 @@ func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) {

// resetObject resets the object state(clear or make it default values)
func resetObject(t *testing.T) {
pools = make(poolMap)
pools = avl.NewTree()
slot0FeeProtocol = 0
poolCreationFee = 100_000_000
withdrawalFeeBPS = 100
Expand Down Expand Up @@ -570,11 +571,11 @@ func burnUsdc(addr pusers.AddressOrName) {

func TestBeforeResetObject(t *testing.T) {
// make some data
pools = make(poolMap)
pools["gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc"] = &Pool{
pools = avl.NewTree()
pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{
token0Path: "gno.land/r/gnoswap/v1/gns",
token1Path: "gno.land/r/onbloc/usdc",
}
})

slot0FeeProtocol = 1
poolCreationFee = 100_000_000
Expand All @@ -591,7 +592,7 @@ func TestBeforeResetObject(t *testing.T) {

func TestResetObject(t *testing.T) {
resetObject(t)
uassert.Equal(t, len(pools), 0)
uassert.Equal(t, pools.Size(), 0)
uassert.Equal(t, slot0FeeProtocol, uint8(0))
uassert.Equal(t, poolCreationFee, uint64(100_000_000))
uassert.Equal(t, withdrawalFeeBPS, uint64(100))
Expand Down
43 changes: 28 additions & 15 deletions pool/api.gno
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
b64 "encoding/base64"

"gno.land/p/demo/json"
u256 "gno.land/p/gnoswap/uint256"

"std"
"strconv"
Expand Down Expand Up @@ -81,14 +82,16 @@ type RpcPosition struct {

func ApiGetPools() string {
rpcPools := []RpcPool{}
for poolPath, _ := range pools {
pools.Iterate("", "", func(poolPath string, value interface{}) bool {
rpcPool := rpcMakePool(poolPath)
rpcPools = append(rpcPools, rpcPool)
}

return false
})

responses := json.ArrayNode("", []*json.Node{})
for _, pool := range rpcPools {
_poolNode := json.ObjectNode("", map[string]*json.Node{
poolNode := json.ObjectNode("", map[string]*json.Node{
"poolPath": json.StringNode("poolPath", pool.PoolPath),
"token0Path": json.StringNode("token0Path", pool.Token0Path),
"token1Path": json.StringNode("token1Path", pool.Token1Path),
Expand All @@ -110,7 +113,7 @@ func ApiGetPools() string {
"tickBitmaps": json.ObjectNode("tickBitmaps", makeRpcTickBitmapsJson(pool.TickBitmaps)),
"positions": json.ArrayNode("positions", makeRpcPositionsArray(pool.Positions)),
})
responses.AppendArray(_poolNode)
responses.AppendArray(poolNode)
}

node := json.ObjectNode("", map[string]*json.Node{
Expand All @@ -122,8 +125,7 @@ func ApiGetPools() string {
}

func ApiGetPool(poolPath string) string {
_, exist := pools[poolPath]
if !exist {
if !pools.Has(poolPath) {
return ""
}
rpcPool := rpcMakePool(poolPath)
Expand Down Expand Up @@ -198,8 +200,11 @@ func rpcMakePool(poolPath string) RpcPool {
rpcPool.Liquidity = pool.liquidity.ToString()

rpcPool.Ticks = RpcTicks{}
for tick, tickInfo := range pool.ticks {
rpcPool.Ticks[tick] = RpcTickInfo{
pool.ticks.Iterate("", "", func(tickStr string, iTickInfo interface{}) bool {
tick, _ := strconv.Atoi(tickStr)
tickInfo := iTickInfo.(TickInfo)

rpcPool.Ticks[int32(tick)] = RpcTickInfo{
LiquidityGross: tickInfo.liquidityGross.ToString(),
LiquidityNet: tickInfo.liquidityNet.ToString(),
FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(),
Expand All @@ -209,17 +214,22 @@ func rpcMakePool(poolPath string) RpcPool {
SecondsOutside: tickInfo.secondsOutside,
Initialized: tickInfo.initialized,
}
}

return false
})

rpcPool.TickBitmaps = RpcTickBitmaps{}
for tick, tickBitmap := range pool.tickBitmaps {
rpcPool.TickBitmaps[tick] = tickBitmap.ToString()
}
pool.tickBitmaps.Iterate("", "", func(tickStr string, iTickBitmap interface{}) bool {
tick, _ := strconv.Atoi(tickStr)
pool.setTickBitmap(int16(tick), iTickBitmap.(*u256.Uint))

return false
})

Positions := pool.positions
rpcPositions := []RpcPosition{}
for posKey, posInfo := range Positions {
pool.positions.Iterate("", "", func(posKey string, iPositionInfo interface{}) bool {
owner, tickLower, tickUpper := posKeyDivide(posKey)
posInfo := iPositionInfo.(PositionInfo)

rpcPositions = append(rpcPositions, RpcPosition{
Owner: owner,
Expand All @@ -229,7 +239,10 @@ func rpcMakePool(poolPath string) RpcPool {
Token0Owed: posInfo.tokensOwed0.ToString(),
Token1Owed: posInfo.tokensOwed1.ToString(),
})
}

return false
})

rpcPool.Positions = rpcPositions

return rpcPool
Expand Down
2 changes: 1 addition & 1 deletion pool/api_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestInitTwoPools(t *testing.T) {
// bar:baz
CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000

uassert.Equal(t, len(pools), 2)
uassert.Equal(t, pools.Size(), 2)
}

func TestApiGetPools(t *testing.T) {
Expand Down
6 changes: 4 additions & 2 deletions pool/getter.gno
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package pool

func PoolGetPoolList() []string {
poolPaths := []string{}
for poolPath, _ := range pools {
pools.Iterate("", "", func(poolPath string, _ interface{}) bool {
poolPaths = append(poolPaths, poolPath)
}

return false
})

return poolPaths
}
Expand Down
16 changes: 9 additions & 7 deletions pool/getter_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package pool
import (
"testing"

"gno.land/p/demo/avl"

i256 "gno.land/p/gnoswap/int256"
u256 "gno.land/p/gnoswap/uint256"
)
Expand Down Expand Up @@ -35,8 +37,8 @@ func TestInitData(t *testing.T) {
liquidity: u256.NewUint(1000000),
}

mockTicks := Ticks{}
mockTicks[0] = TickInfo{
mockTicks := avl.NewTree()
mockTicks.Set("0", TickInfo{
liquidityGross: u256.NewUint(1000000),
liquidityNet: i256.NewInt(2000000),
feeGrowthOutside0X128: u256.NewUint(3000000),
Expand All @@ -45,20 +47,20 @@ func TestInitData(t *testing.T) {
secondsPerLiquidityOutsideX128: u256.NewUint(6000000),
secondsOutside: 7,
initialized: true,
}
})
mockPool.ticks = mockTicks

mockPositions := Positions{}
mockPositions["test_position"] = PositionInfo{
mockPositions := avl.NewTree()
mockPositions.Set("test_position", PositionInfo{
liquidity: u256.NewUint(1000000),
feeGrowthInside0LastX128: u256.NewUint(2000000),
feeGrowthInside1LastX128: u256.NewUint(3000000),
tokensOwed0: u256.NewUint(4000000),
tokensOwed1: u256.NewUint(5000000),
}
})
mockPool.positions = mockPositions

pools["token0:token1:3000"] = mockPool
pools.Set("token0:token1:3000", mockPool)
}

func TestPoolGetters(t *testing.T) {
Expand Down
17 changes: 10 additions & 7 deletions pool/pool.gno
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func Burn(
}

positionKey := getPositionKey(caller, tickLower, tickUpper)
pool.positions[positionKey] = position
pool.setPosition(positionKey, position)

// actual token transfer happens in Collect()
return amount0.ToString(), amount1.ToString()
Expand Down Expand Up @@ -192,7 +192,7 @@ func Collect(
checkTransferError(token1.Transfer(recipient, amount1.Uint64()))
}

pool.positions[positionKey] = position
pool.setPosition(positionKey, position)

return amount0.ToString(), amount1.ToString()
}
Expand Down Expand Up @@ -316,12 +316,15 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 {
// - feePrtocol0 occupies the lower 4 bits
// - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits
newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 )

// Update slot0 for each pool
for _, pool := range pools {
if pool != nil {
pool.slot0.feeProtocol = newFee
}
}
pools.Iterate("", "", func(poolPath string, iPool interface{}) bool {
pool := iPool.(*Pool)
pool.slot0.feeProtocol = newFee

return false
})

// update slot0
slot0FeeProtocol = newFee
return newFee
Expand Down
82 changes: 50 additions & 32 deletions pool/pool_manager.gno
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package pool

import (
"std"
"strconv"
"strings"

"gno.land/p/demo/avl"
"gno.land/p/demo/ufmt"

"gno.land/r/gnoswap/v1/consts"
Expand All @@ -15,39 +17,18 @@ import (
u256 "gno.land/p/gnoswap/uint256"
)

type poolMap map[string]*Pool

func (pm *poolMap) Get(poolPath string) (*Pool, bool) {
pool, exist := (*pm)[poolPath]
if !exist {
return nil, false
}

return pool, true
}

func (pm *poolMap) Set(poolPath string, pool *Pool) {
(*pm)[poolPath] = pool
}

type tickSpacingMap map[uint32]int32

func (t *tickSpacingMap) Get(fee uint32) int32 {
return (*t)[fee]
}

var (
feeAmountTickSpacing tickSpacingMap = make(tickSpacingMap) // maps fee to tickSpacing || map[feeAmount]tick_spacing
pools poolMap = make(poolMap) // maps poolPath to pool || map[poolPath]*Pool
feeAmountTickSpacing = avl.NewTree() // feeBps(uint32) -> tickSpacing(int32)
pools = avl.NewTree() // poolPath -> *Pool

slot0FeeProtocol uint8 = 0
)

func init() {
feeAmountTickSpacing[100] = 1 // 0.01%
feeAmountTickSpacing[500] = 10 // 0.05%
feeAmountTickSpacing[3000] = 60 // 0.3%
feeAmountTickSpacing[10000] = 200 // 1%
setFeeAmountTickSpacing(100, 1) // 0.01%
setFeeAmountTickSpacing(500, 10) // 0.05%
setFeeAmountTickSpacing(3000, 60) // 0.3%
setFeeAmountTickSpacing(10000, 200) // 1%
}

// createPoolParams holds the essential parameters for creating a new pool.
Expand All @@ -66,7 +47,7 @@ func newPoolParams(
sqrtPriceX96 string,
) *createPoolParams {
price := u256.MustFromDecimal(sqrtPriceX96)
tickSpacing := feeAmountTickSpacing.Get(fee)
tickSpacing := GetFeeAmountTickSpacing(fee)
return &createPoolParams{
token0Path: token0Path,
token1Path: token1Path,
Expand Down Expand Up @@ -211,8 +192,7 @@ func CreatePool(
// DoesPoolPathExist checks if a pool exists for the given poolPath.
// The poolPath is a unique identifier for a pool, combining token paths and fee.
func DoesPoolPathExist(poolPath string) bool {
_, exist := pools[poolPath]
return exist
return pools.Has(poolPath)
}

// GetPool retrieves a pool instance based on the provided token paths and fee tier.
Expand Down Expand Up @@ -260,14 +240,14 @@ func GetPool(token0Path, token1Path string, fee uint32) *Pool {
// Example:
// pool := GetPoolFromPoolPath("path/to/pool")
func GetPoolFromPoolPath(poolPath string) *Pool {
pool, exist := pools[poolPath]
iPool, exist := pools.Get(poolPath)
if !exist {
panic(addDetailToError(
errDataNotFound,
ufmt.Sprintf("expected poolPath(%s) to exist", poolPath),
))
}
return pool
return iPool.(*Pool)
}

// GetPoolPath generates a unique pool path string based on the token paths and fee tier.
Expand Down Expand Up @@ -303,3 +283,41 @@ func GetPoolPath(token0Path, token1Path string, fee uint32) string {
}
return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee)
}

// GetFeeAmountTickSpacing retrieves the tick spacing associated with a given fee amount.
// The tick spacing determines the minimum distance between ticks in the pool.
//
// Parameters:
// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%)
//
// Returns:
// - int32: The tick spacing value for the given fee tier
//
// Panics:
// - If the fee amount is not registered in feeAmountTickSpacing
func GetFeeAmountTickSpacing(fee uint32) int32 {
feeStr := strconv.FormatUint(uint64(fee), 10)
iTickSpacing, exist := feeAmountTickSpacing.Get(feeStr)
if !exist {
panic(addDetailToError(
errDataNotFound,
ufmt.Sprintf("expected feeAmountTickSpacing(%s) to exist", feeStr),
))
}

return iTickSpacing.(int32)
}

// setFeeAmountTickSpacing associates a tick spacing value with a fee amount.
// This is typically called during initialization to set up supported fee tiers.
//
// Parameters:
// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%)
// - tickSpacing (int32): The minimum tick spacing for this fee tier
//
// Note: Smaller tick spacing allows for more granular price points but increases
// computational overhead. Higher fee tiers typically use larger tick spacing.
func setFeeAmountTickSpacing(fee uint32, tickSpacing int32) {
feeStr := strconv.FormatUint(uint64(fee), 10)
feeAmountTickSpacing.Set(feeStr, tickSpacing)
}
Loading
Loading