Skip to content

Commit

Permalink
AG-113: improve gas estimation for uniswap-v3 (#15)
Browse files Browse the repository at this point in the history
* refactor: swap return structure result

* feat: add count cross tick loops

* refactor: rename and only count for initialized ticks
  • Loading branch information
lehainam-dev authored Dec 15, 2023
1 parent 3015a9a commit 71a06a9
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 41 deletions.
105 changes: 77 additions & 28 deletions entities/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
)

var (
ErrFeeTooHigh = errors.New("Fee too high")
ErrInvalidSqrtRatioX96 = errors.New("Invalid sqrtRatioX96")
ErrTokenNotInvolved = errors.New("Token not involved in pool")
ErrFeeTooHigh = errors.New("fee too high")
ErrInvalidSqrtRatioX96 = errors.New("invalid sqrtRatioX96")
ErrTokenNotInvolved = errors.New("token not involved in pool")
ErrSqrtPriceLimitX96TooLow = errors.New("SqrtPriceLimitX96 too low")
ErrSqrtPriceLimitX96TooHigh = errors.New("SqrtPriceLimitX96 too high")
)
Expand Down Expand Up @@ -42,6 +42,20 @@ type Pool struct {
token1Price *entities.Price
}

type SwapResult struct {
amountCalculated *big.Int
sqrtRatioX96 *big.Int
liquidity *big.Int
currentTick int
crossInitTickLoops int
}

type GetAmountResult struct {
ReturnedAmount *entities.CurrencyAmount
NewPoolState *Pool
CrossInitTickLoops int
}

func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) {
return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride)
}
Expand Down Expand Up @@ -148,26 +162,38 @@ func (p *Pool) ChainID() uint {
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit
* @returns The output amount and the pool with updated state
*/
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) {
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*GetAmountResult, error) {
if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) {
return nil, nil, ErrTokenNotInvolved
return nil, ErrTokenNotInvolved
}
zeroForOne := inputAmount.Currency.Equal(p.Token0)
outputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96)
swapResult, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96)
if err != nil {
return nil, nil, err
return nil, err
}
var outputToken *entities.Token
if zeroForOne {
outputToken = p.Token1
} else {
outputToken = p.Token0
}
pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider)
pool, err := NewPool(
p.Token0,
p.Token1,
p.Fee,
swapResult.sqrtRatioX96,
swapResult.liquidity,
swapResult.currentTick,
p.TickDataProvider,
)
if err != nil {
return nil, nil, err
return nil, err
}
return entities.FromRawAmount(outputToken, new(big.Int).Mul(outputAmount, constants.NegativeOne)), pool, nil
return &GetAmountResult{
ReturnedAmount: entities.FromRawAmount(outputToken, new(big.Int).Mul(swapResult.amountCalculated, constants.NegativeOne)),
NewPoolState: pool,
CrossInitTickLoops: swapResult.crossInitTickLoops,
}, nil
}

/**
Expand All @@ -181,7 +207,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
return nil, nil, ErrTokenNotInvolved
}
zeroForOne := outputAmount.Currency.Equal(p.Token1)
inputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96)
swapResult, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96)
if err != nil {
return nil, nil, err
}
Expand All @@ -191,24 +217,33 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
} else {
inputToken = p.Token1
}
pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider)
pool, err := NewPool(
p.Token0,
p.Token1,
p.Fee,
swapResult.sqrtRatioX96,
swapResult.liquidity,
swapResult.currentTick,
p.TickDataProvider,
)
if err != nil {
return nil, nil, err
}
return entities.FromRawAmount(inputToken, inputAmount), pool, nil
return entities.FromRawAmount(inputToken, swapResult.amountCalculated), pool, nil
}

/**
* Executes a swap
* @param zeroForOne Whether the amount in is token0 or token1
* @param amountSpecified The amount of the swap, which implicitly configures the swap as exact input (positive), or exact output (negative)
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap
* @returns amountCalculated
* @returns sqrtRatioX96
* @returns liquidity
* @returns tickCurrent
* @returns swapResult.amountCalculated
* @returns swapResult.sqrtRatioX96
* @returns swapResult.liquidity
* @returns swapResult.tickCurrent
*/
func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (amountCalCulated *big.Int, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, err error) {
func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (*SwapResult, error) {
var err error
if sqrtPriceLimitX96 == nil {
if zeroForOne {
sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One)
Expand All @@ -219,17 +254,17 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int

if zeroForOne {
if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 {
return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow
return nil, ErrSqrtPriceLimitX96TooLow
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 {
return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh
return nil, ErrSqrtPriceLimitX96TooHigh
}
} else {
if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 {
return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh
return nil, ErrSqrtPriceLimitX96TooHigh
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 {
return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow
return nil, ErrSqrtPriceLimitX96TooLow
}
}

Expand All @@ -251,6 +286,10 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
liquidity: p.Liquidity,
}

// crossInitTickLoops is the number of loops that cross an initialized tick.
// We only count when tick passes an initialized tick, since gas only significant in this case.
crossInitTickLoops := 0

// start swap while loop
for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
var step StepComputations
Expand All @@ -261,7 +300,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
// tickBitmap.nextInitializedTickWithinOneWord
step.tickNext, step.initialized, err = p.TickDataProvider.NextInitializedTickIndex(state.tick, zeroForOne)
if err != nil {
return nil, nil, nil, 0, err
return nil, err
}

if step.tickNext < utils.MinTick {
Expand All @@ -272,7 +311,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int

step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext)
if err != nil {
return nil, nil, nil, 0, err
return nil, err
}
var targetValue *big.Int
if zeroForOne {
Expand All @@ -291,7 +330,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int

state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount, err = utils.ComputeSwapStep(state.sqrtPriceX96, targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee)
if err != nil {
return nil, nil, nil, 0, err
return nil, err
}

if exactInput {
Expand All @@ -308,7 +347,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
if step.initialized {
tick, err := p.TickDataProvider.GetTick(step.tickNext)
if err != nil {
return nil, nil, nil, 0, err
return nil, err
}

liquidityNet := tick.LiquidityNet
Expand All @@ -318,21 +357,31 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne)
}
state.liquidity = utils.AddDelta(state.liquidity, liquidityNet)

crossInitTickLoops++
}
if zeroForOne {
state.tick = step.tickNext - 1
} else {
state.tick = step.tickNext
}

} else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 {
// recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved
state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96)
if err != nil {
return nil, nil, nil, 0, err
return nil, err
}
}
}
return state.amountCalculated, state.sqrtPriceX96, state.liquidity, state.tick, nil
return &SwapResult{
amountCalculated: state.amountCalculated,
sqrtRatioX96: state.sqrtPriceX96,
liquidity: state.liquidity,
currentTick: state.tick,

crossInitTickLoops: crossInitTickLoops,
}, nil
}

func (p *Pool) tickSpacing() int {
Expand Down
12 changes: 6 additions & 6 deletions entities/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,21 @@ func TestGetOutputAmount(t *testing.T) {

// USDC -> DAI
inputAmount := entities.FromRawAmount(USDC, big.NewInt(100))
outputAmount, _, err := pool.GetOutputAmount(inputAmount, nil)
outputAmount, err := pool.GetOutputAmount(inputAmount, nil)
if err != nil {
t.Fatal(err)
}
assert.True(t, outputAmount.Currency.Equal(DAI))
assert.Equal(t, outputAmount.Quotient(), big.NewInt(98))
assert.True(t, outputAmount.ReturnedAmount.Currency.Equal(DAI))
assert.Equal(t, outputAmount.ReturnedAmount.Quotient(), big.NewInt(98))

// DAI -> USDC
inputAmount = entities.FromRawAmount(DAI, big.NewInt(100))
outputAmount, _, err = pool.GetOutputAmount(inputAmount, nil)
outputAmount, err = pool.GetOutputAmount(inputAmount, nil)
if err != nil {
t.Fatal(err)
}
assert.True(t, outputAmount.Currency.Equal(USDC))
assert.Equal(t, outputAmount.Quotient(), big.NewInt(98))
assert.True(t, outputAmount.ReturnedAmount.Currency.Equal(USDC))
assert.Equal(t, outputAmount.ReturnedAmount.Quotient(), big.NewInt(98))
}

func TestGetInputAmount(t *testing.T) {
Expand Down
14 changes: 7 additions & 7 deletions entities/trade.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,11 @@ func FromRoute(route *Route, amount *entities.CurrencyAmount, tradeType entities
amounts[0] = amount.Wrapped()
for i := 0; i < len(route.TokenPath)-1; i++ {
pool := route.Pools[i]
outputAmount, _, err = pool.GetOutputAmount(amounts[i], nil)
outputResult, err := pool.GetOutputAmount(amounts[i], nil)
if err != nil {
return nil, err
}
amounts[i+1] = outputAmount
amounts[i+1] = outputResult.ReturnedAmount
}
inputAmount = entities.FromFractionalAmount(route.Input, amount.Numerator, amount.Denominator)
outputAmount = entities.FromFractionalAmount(route.Output, amounts[len(amounts)-1].Numerator, amounts[len(amounts)-1].Denominator)
Expand Down Expand Up @@ -281,11 +281,11 @@ func FromRoutes(wrappedRoutes []*WrappedRoute, tradeType entities.TradeType) (*T
amounts[0] = entities.FromFractionalAmount(route.Input.Wrapped(), amount.Numerator, amount.Denominator)
for i := 0; i < len(route.TokenPath)-1; i++ {
pool := route.Pools[i]
outputAmount, _, err := pool.GetOutputAmount(amounts[i], nil)
outputResult, err := pool.GetOutputAmount(amounts[i], nil)
if err != nil {
return nil, err
}
amounts[i+1] = outputAmount
amounts[i+1] = outputResult.ReturnedAmount
}
inputAmount = entities.FromFractionalAmount(route.Input, amount.Numerator, amount.Denominator)
outputAmount = entities.FromFractionalAmount(route.Output, amounts[len(amounts)-1].Numerator, amounts[len(amounts)-1].Denominator)
Expand Down Expand Up @@ -494,7 +494,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount,
if !pool.Token0.Equal(amountIn.Currency) && !pool.Token1.Equal(amountIn.Currency) {
continue
}
amountOut, _, err := pool.GetOutputAmount(amountIn, nil)
amountOut, err := pool.GetOutputAmount(amountIn, nil)
if err != nil {
// TODO
// input too low
Expand All @@ -504,7 +504,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount,
return nil, err
}
// we have arrived at the output token, so this is the final trade of one of the paths
if amountOut.Currency.IsToken() && amountOut.Currency.Equal(tokenOut) {
if amountOut.ReturnedAmount.Currency.IsToken() && amountOut.ReturnedAmount.Currency.Equal(tokenOut) {
r, err := NewRoute(append(currentPools, pool), currencyAmountIn.Currency, currencyOut)
if err != nil {
return nil, err
Expand All @@ -523,7 +523,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount,
poolsExcludingThisPool = append(poolsExcludingThisPool, pools[i+1:]...)

// otherwise, consider all the other paths that lead from this token as long as we have not exceeded maxHops
bestTrades, err = BestTradeExactIn(poolsExcludingThisPool, currencyAmountIn, currencyOut, &BestTradeOptions{MaxNumResults: opts.MaxNumResults, MaxHops: opts.MaxHops - 1}, append(currentPools, pool), amountOut, bestTrades)
bestTrades, err = BestTradeExactIn(poolsExcludingThisPool, currencyAmountIn, currencyOut, &BestTradeOptions{MaxNumResults: opts.MaxNumResults, MaxHops: opts.MaxHops - 1}, append(currentPools, pool), amountOut.ReturnedAmount, bestTrades)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 71a06a9

Please sign in to comment.