diff --git a/entities/pool.go b/entities/pool.go index 9aff6a8..8473375 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -11,9 +11,9 @@ import ( ) var ( - ErrFeeTooHigh = errors.New("Fee too high") - ErrInvalidSqrtRatioX96 = errors.New("Invalid sqrtRatioX96") - ErrTokenNotInvolved = errors.New("Token not involved in pool") + ErrFeeTooHigh = errors.New("fee too high") + ErrInvalidSqrtRatioX96 = errors.New("invalid sqrtRatioX96") + ErrTokenNotInvolved = errors.New("token not involved in pool") ErrSqrtPriceLimitX96TooLow = errors.New("SqrtPriceLimitX96 too low") ErrSqrtPriceLimitX96TooHigh = errors.New("SqrtPriceLimitX96 too high") ) @@ -42,6 +42,20 @@ type Pool struct { token1Price *entities.Price } +type SwapResult struct { + amountCalculated *big.Int + sqrtRatioX96 *big.Int + liquidity *big.Int + currentTick int + crossInitTickLoops int +} + +type GetAmountResult struct { + ReturnedAmount *entities.CurrencyAmount + NewPoolState *Pool + CrossInitTickLoops int +} + func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) { return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride) } @@ -148,14 +162,14 @@ func (p *Pool) ChainID() uint { * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit * @returns The output amount and the pool with updated state */ -func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) { +func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*GetAmountResult, error) { if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) { - return nil, nil, ErrTokenNotInvolved + return nil, ErrTokenNotInvolved } zeroForOne := inputAmount.Currency.Equal(p.Token0) - outputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96) + swapResult, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96) if err != nil { - return nil, nil, err + return nil, err } var outputToken *entities.Token if zeroForOne { @@ -163,11 +177,23 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi } else { outputToken = p.Token0 } - pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider) + pool, err := NewPool( + p.Token0, + p.Token1, + p.Fee, + swapResult.sqrtRatioX96, + swapResult.liquidity, + swapResult.currentTick, + p.TickDataProvider, + ) if err != nil { - return nil, nil, err + return nil, err } - return entities.FromRawAmount(outputToken, new(big.Int).Mul(outputAmount, constants.NegativeOne)), pool, nil + return &GetAmountResult{ + ReturnedAmount: entities.FromRawAmount(outputToken, new(big.Int).Mul(swapResult.amountCalculated, constants.NegativeOne)), + NewPoolState: pool, + CrossInitTickLoops: swapResult.crossInitTickLoops, + }, nil } /** @@ -181,7 +207,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi return nil, nil, ErrTokenNotInvolved } zeroForOne := outputAmount.Currency.Equal(p.Token1) - inputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96) + swapResult, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96) if err != nil { return nil, nil, err } @@ -191,11 +217,19 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi } else { inputToken = p.Token1 } - pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider) + pool, err := NewPool( + p.Token0, + p.Token1, + p.Fee, + swapResult.sqrtRatioX96, + swapResult.liquidity, + swapResult.currentTick, + p.TickDataProvider, + ) if err != nil { return nil, nil, err } - return entities.FromRawAmount(inputToken, inputAmount), pool, nil + return entities.FromRawAmount(inputToken, swapResult.amountCalculated), pool, nil } /** @@ -203,12 +237,13 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi * @param zeroForOne Whether the amount in is token0 or token1 * @param amountSpecified The amount of the swap, which implicitly configures the swap as exact input (positive), or exact output (negative) * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap - * @returns amountCalculated - * @returns sqrtRatioX96 - * @returns liquidity - * @returns tickCurrent + * @returns swapResult.amountCalculated + * @returns swapResult.sqrtRatioX96 + * @returns swapResult.liquidity + * @returns swapResult.tickCurrent */ -func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (amountCalCulated *big.Int, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, err error) { +func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (*SwapResult, error) { + var err error if sqrtPriceLimitX96 == nil { if zeroForOne { sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One) @@ -219,17 +254,17 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int if zeroForOne { if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 { - return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow + return nil, ErrSqrtPriceLimitX96TooLow } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 { - return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh + return nil, ErrSqrtPriceLimitX96TooHigh } } else { if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 { - return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh + return nil, ErrSqrtPriceLimitX96TooHigh } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 { - return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow + return nil, ErrSqrtPriceLimitX96TooLow } } @@ -251,6 +286,10 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int liquidity: p.Liquidity, } + // crossInitTickLoops is the number of loops that cross an initialized tick. + // We only count when tick passes an initialized tick, since gas only significant in this case. + crossInitTickLoops := 0 + // start swap while loop for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 { var step StepComputations @@ -261,7 +300,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int // tickBitmap.nextInitializedTickWithinOneWord step.tickNext, step.initialized, err = p.TickDataProvider.NextInitializedTickIndex(state.tick, zeroForOne) if err != nil { - return nil, nil, nil, 0, err + return nil, err } if step.tickNext < utils.MinTick { @@ -272,7 +311,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext) if err != nil { - return nil, nil, nil, 0, err + return nil, err } var targetValue *big.Int if zeroForOne { @@ -291,7 +330,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount, err = utils.ComputeSwapStep(state.sqrtPriceX96, targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee) if err != nil { - return nil, nil, nil, 0, err + return nil, err } if exactInput { @@ -308,7 +347,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int if step.initialized { tick, err := p.TickDataProvider.GetTick(step.tickNext) if err != nil { - return nil, nil, nil, 0, err + return nil, err } liquidityNet := tick.LiquidityNet @@ -318,21 +357,31 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne) } state.liquidity = utils.AddDelta(state.liquidity, liquidityNet) + + crossInitTickLoops++ } if zeroForOne { state.tick = step.tickNext - 1 } else { state.tick = step.tickNext } + } else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96) if err != nil { - return nil, nil, nil, 0, err + return nil, err } } } - return state.amountCalculated, state.sqrtPriceX96, state.liquidity, state.tick, nil + return &SwapResult{ + amountCalculated: state.amountCalculated, + sqrtRatioX96: state.sqrtPriceX96, + liquidity: state.liquidity, + currentTick: state.tick, + + crossInitTickLoops: crossInitTickLoops, + }, nil } func (p *Pool) tickSpacing() int { diff --git a/entities/pool_test.go b/entities/pool_test.go index 62d68c2..727cd1c 100644 --- a/entities/pool_test.go +++ b/entities/pool_test.go @@ -142,21 +142,21 @@ func TestGetOutputAmount(t *testing.T) { // USDC -> DAI inputAmount := entities.FromRawAmount(USDC, big.NewInt(100)) - outputAmount, _, err := pool.GetOutputAmount(inputAmount, nil) + outputAmount, err := pool.GetOutputAmount(inputAmount, nil) if err != nil { t.Fatal(err) } - assert.True(t, outputAmount.Currency.Equal(DAI)) - assert.Equal(t, outputAmount.Quotient(), big.NewInt(98)) + assert.True(t, outputAmount.ReturnedAmount.Currency.Equal(DAI)) + assert.Equal(t, outputAmount.ReturnedAmount.Quotient(), big.NewInt(98)) // DAI -> USDC inputAmount = entities.FromRawAmount(DAI, big.NewInt(100)) - outputAmount, _, err = pool.GetOutputAmount(inputAmount, nil) + outputAmount, err = pool.GetOutputAmount(inputAmount, nil) if err != nil { t.Fatal(err) } - assert.True(t, outputAmount.Currency.Equal(USDC)) - assert.Equal(t, outputAmount.Quotient(), big.NewInt(98)) + assert.True(t, outputAmount.ReturnedAmount.Currency.Equal(USDC)) + assert.Equal(t, outputAmount.ReturnedAmount.Quotient(), big.NewInt(98)) } func TestGetInputAmount(t *testing.T) { diff --git a/entities/trade.go b/entities/trade.go index 688f8ba..7af06d5 100644 --- a/entities/trade.go +++ b/entities/trade.go @@ -217,11 +217,11 @@ func FromRoute(route *Route, amount *entities.CurrencyAmount, tradeType entities amounts[0] = amount.Wrapped() for i := 0; i < len(route.TokenPath)-1; i++ { pool := route.Pools[i] - outputAmount, _, err = pool.GetOutputAmount(amounts[i], nil) + outputResult, err := pool.GetOutputAmount(amounts[i], nil) if err != nil { return nil, err } - amounts[i+1] = outputAmount + amounts[i+1] = outputResult.ReturnedAmount } inputAmount = entities.FromFractionalAmount(route.Input, amount.Numerator, amount.Denominator) outputAmount = entities.FromFractionalAmount(route.Output, amounts[len(amounts)-1].Numerator, amounts[len(amounts)-1].Denominator) @@ -281,11 +281,11 @@ func FromRoutes(wrappedRoutes []*WrappedRoute, tradeType entities.TradeType) (*T amounts[0] = entities.FromFractionalAmount(route.Input.Wrapped(), amount.Numerator, amount.Denominator) for i := 0; i < len(route.TokenPath)-1; i++ { pool := route.Pools[i] - outputAmount, _, err := pool.GetOutputAmount(amounts[i], nil) + outputResult, err := pool.GetOutputAmount(amounts[i], nil) if err != nil { return nil, err } - amounts[i+1] = outputAmount + amounts[i+1] = outputResult.ReturnedAmount } inputAmount = entities.FromFractionalAmount(route.Input, amount.Numerator, amount.Denominator) outputAmount = entities.FromFractionalAmount(route.Output, amounts[len(amounts)-1].Numerator, amounts[len(amounts)-1].Denominator) @@ -494,7 +494,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount, if !pool.Token0.Equal(amountIn.Currency) && !pool.Token1.Equal(amountIn.Currency) { continue } - amountOut, _, err := pool.GetOutputAmount(amountIn, nil) + amountOut, err := pool.GetOutputAmount(amountIn, nil) if err != nil { // TODO // input too low @@ -504,7 +504,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount, return nil, err } // we have arrived at the output token, so this is the final trade of one of the paths - if amountOut.Currency.IsToken() && amountOut.Currency.Equal(tokenOut) { + if amountOut.ReturnedAmount.Currency.IsToken() && amountOut.ReturnedAmount.Currency.Equal(tokenOut) { r, err := NewRoute(append(currentPools, pool), currencyAmountIn.Currency, currencyOut) if err != nil { return nil, err @@ -523,7 +523,7 @@ func BestTradeExactIn(pools []*Pool, currencyAmountIn *entities.CurrencyAmount, poolsExcludingThisPool = append(poolsExcludingThisPool, pools[i+1:]...) // otherwise, consider all the other paths that lead from this token as long as we have not exceeded maxHops - bestTrades, err = BestTradeExactIn(poolsExcludingThisPool, currencyAmountIn, currencyOut, &BestTradeOptions{MaxNumResults: opts.MaxNumResults, MaxHops: opts.MaxHops - 1}, append(currentPools, pool), amountOut, bestTrades) + bestTrades, err = BestTradeExactIn(poolsExcludingThisPool, currencyAmountIn, currencyOut, &BestTradeOptions{MaxNumResults: opts.MaxNumResults, MaxHops: opts.MaxHops - 1}, append(currentPools, pool), amountOut.ReturnedAmount, bestTrades) if err != nil { return nil, err }