diff --git a/constants/constants.go b/constants/constants.go index 5f5ff05..301b459 100644 --- a/constants/constants.go +++ b/constants/constants.go @@ -5,6 +5,7 @@ import ( "github.com/daoleno/uniswap-sdk-core/entities" "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" ) const PoolInitCodeHash = "0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54" @@ -51,5 +52,7 @@ var ( Q96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil) Q192 = new(big.Int).Exp(Q96, big.NewInt(2), nil) + Q96U256 = new(uint256.Int).Exp(uint256.NewInt(2), uint256.NewInt(96)) + PercentZero = entities.NewFraction(big.NewInt(0), big.NewInt(1)) ) diff --git a/entities/pool.go b/entities/pool.go index 936c8ea..98ed498 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -4,10 +4,12 @@ import ( "errors" "math/big" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" "github.com/daoleno/uniswap-sdk-core/entities" "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" ) var ( @@ -19,13 +21,13 @@ var ( ) type StepComputations struct { - sqrtPriceStartX96 *big.Int + sqrtPriceStartX96 *utils.Uint160 tickNext int initialized bool - sqrtPriceNextX96 *big.Int - amountIn *big.Int - amountOut *big.Int - feeAmount *big.Int + sqrtPriceNextX96 *utils.Uint160 + amountIn *utils.Uint256 + amountOut *utils.Uint256 + feeAmount *utils.Uint256 } // Represents a V3 pool @@ -33,8 +35,8 @@ type Pool struct { Token0 *entities.Token Token1 *entities.Token Fee constants.FeeAmount - SqrtRatioX96 *big.Int - Liquidity *big.Int + SqrtRatioX96 *utils.Uint160 + Liquidity *utils.Uint128 TickCurrent int TickDataProvider TickDataProvider @@ -43,10 +45,10 @@ type Pool struct { } type SwapResult struct { - amountCalculated *big.Int - sqrtRatioX96 *big.Int - liquidity *big.Int - remainingAmountIn *big.Int + amountCalculated *utils.Int256 + sqrtRatioX96 *utils.Uint160 + liquidity *utils.Uint128 + remainingAmountIn *utils.Int256 currentTick int crossInitTickLoops int } @@ -62,6 +64,17 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride) } +// deprecated +func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) { + return NewPoolV2( + tokenA, tokenB, fee, + uint256.MustFromBig(sqrtRatioX96), + uint256.MustFromBig(liquidity), + tickCurrent, + ticks, + ) +} + /** * Construct a pool * @param tokenA One of the tokens in the pool @@ -72,16 +85,16 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod * @param tickCurrent The current tick of the pool * @param ticks The current state of the pool ticks or a data provider that can return tick data */ -func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) { +func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *utils.Uint160, liquidity *utils.Uint128, tickCurrent int, ticks TickDataProvider) (*Pool, error) { if fee >= constants.FeeMax { return nil, ErrFeeTooHigh } - tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent) + tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent) if err != nil { return nil, err } - nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent + 1) + nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent + 1) if err != nil { return nil, err } @@ -125,7 +138,7 @@ func (p *Pool) Token0Price() *entities.Price { if p.token0Price != nil { return p.token0Price } - p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96)) + p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig()) return p.token0Price } @@ -134,7 +147,7 @@ func (p *Pool) Token1Price() *entities.Price { if p.token1Price != nil { return p.token1Price } - p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96), constants.Q192) + p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig(), constants.Q192) return p.token1Price } @@ -164,12 +177,16 @@ func (p *Pool) ChainID() uint { * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit * @returns The output amount and the pool with updated state */ -func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*GetAmountResult, error) { +func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResult, error) { if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) { return nil, ErrTokenNotInvolved } zeroForOne := inputAmount.Currency.Equal(p.Token0) - swapResult, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96) + q, err := int256.FromBig(inputAmount.Quotient()) + if err != nil { + return nil, err + } + swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96) if err != nil { return nil, err } @@ -179,7 +196,7 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi } else { outputToken = p.Token0 } - pool, err := NewPool( + pool, err := NewPoolV2( p.Token0, p.Token1, p.Fee, @@ -192,8 +209,8 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi return nil, err } return &GetAmountResult{ - ReturnedAmount: entities.FromRawAmount(outputToken, new(big.Int).Mul(swapResult.amountCalculated, constants.NegativeOne)), - RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn), + ReturnedAmount: entities.FromRawAmount(outputToken, new(utils.Int256).Neg(swapResult.amountCalculated).ToBig()), + RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn.ToBig()), NewPoolState: pool, CrossInitTickLoops: swapResult.crossInitTickLoops, }, nil @@ -205,12 +222,17 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap * @returns The input amount and the pool with updated state */ -func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) { +func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*entities.CurrencyAmount, *Pool, error) { if !(outputAmount.Currency.IsToken() && p.InvolvesToken(outputAmount.Currency.Wrapped())) { return nil, nil, ErrTokenNotInvolved } zeroForOne := outputAmount.Currency.Equal(p.Token1) - swapResult, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96) + q, err := int256.FromBig(outputAmount.Quotient()) + if err != nil { + return nil, nil, err + } + q.Neg(q) + swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96) if err != nil { return nil, nil, err } @@ -220,7 +242,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi } else { inputToken = p.Token1 } - pool, err := NewPool( + pool, err := NewPoolV2( p.Token0, p.Token1, p.Fee, @@ -232,7 +254,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi if err != nil { return nil, nil, err } - return entities.FromRawAmount(inputToken, swapResult.amountCalculated), pool, nil + return entities.FromRawAmount(inputToken, swapResult.amountCalculated.ToBig()), pool, nil } /** @@ -245,25 +267,25 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi * @returns swapResult.liquidity * @returns swapResult.tickCurrent */ -func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (*SwapResult, error) { +func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLimitX96 *utils.Uint160) (*SwapResult, error) { var err error if sqrtPriceLimitX96 == nil { if zeroForOne { - sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One) + sqrtPriceLimitX96 = new(uint256.Int).AddUint64(utils.MinSqrtRatioU256, 1) } else { - sqrtPriceLimitX96 = new(big.Int).Sub(utils.MaxSqrtRatio, constants.One) + sqrtPriceLimitX96 = new(uint256.Int).SubUint64(utils.MaxSqrtRatioU256, 1) } } if zeroForOne { - if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 { + if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatioU256) < 0 { return nil, ErrSqrtPriceLimitX96TooLow } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 { return nil, ErrSqrtPriceLimitX96TooHigh } } else { - if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 { + if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatioU256) > 0 { return nil, ErrSqrtPriceLimitX96TooHigh } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 { @@ -271,22 +293,22 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int } } - exactInput := amountSpecified.Cmp(constants.Zero) >= 0 + exactInput := amountSpecified.Sign() >= 0 // keep track of swap state state := struct { - amountSpecifiedRemaining *big.Int - amountCalculated *big.Int - sqrtPriceX96 *big.Int + amountSpecifiedRemaining *utils.Int256 + amountCalculated *utils.Int256 + sqrtPriceX96 *utils.Uint160 tick int - liquidity *big.Int + liquidity *utils.Uint128 }{ - amountSpecifiedRemaining: amountSpecified, - amountCalculated: constants.Zero, - sqrtPriceX96: p.SqrtRatioX96, + amountSpecifiedRemaining: new(utils.Int256).Set(amountSpecified), + amountCalculated: int256.NewInt(0), + sqrtPriceX96: new(utils.Uint160).Set(p.SqrtRatioX96), tick: p.TickCurrent, - liquidity: p.Liquidity, + liquidity: new(utils.Uint128).Set(p.Liquidity), } // crossInitTickLoops is the number of loops that cross an initialized tick. @@ -294,7 +316,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int crossInitTickLoops := 0 // start swap while loop - for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 { + for !state.amountSpecifiedRemaining.IsZero() && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 { var step StepComputations step.sqrtPriceStartX96 = state.sqrtPriceX96 @@ -312,11 +334,11 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int step.tickNext = utils.MaxTick } - step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext) + step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTickV2(step.tickNext) if err != nil { return nil, err } - var targetValue *big.Int + var targetValue *utils.Uint160 if zeroForOne { if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 { targetValue = sqrtPriceLimitX96 @@ -336,12 +358,27 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int return nil, err } + var amountInPlusFee utils.Uint256 + amountInPlusFee.Add(step.amountIn, step.feeAmount) + + var amountInPlusFeeSigned utils.Int256 + err = utils.ToInt256(&amountInPlusFee, &amountInPlusFeeSigned) + if err != nil { + return nil, err + } + + var amountOutSigned utils.Int256 + err = utils.ToInt256(step.amountOut, &amountOutSigned) + if err != nil { + return nil, err + } + if exactInput { - state.amountSpecifiedRemaining = new(big.Int).Sub(state.amountSpecifiedRemaining, new(big.Int).Add(step.amountIn, step.feeAmount)) - state.amountCalculated = new(big.Int).Sub(state.amountCalculated, step.amountOut) + state.amountSpecifiedRemaining.Sub(state.amountSpecifiedRemaining, &amountInPlusFeeSigned) + state.amountCalculated.Sub(state.amountCalculated, &amountOutSigned) } else { - state.amountSpecifiedRemaining = new(big.Int).Add(state.amountSpecifiedRemaining, step.amountOut) - state.amountCalculated = new(big.Int).Add(state.amountCalculated, new(big.Int).Add(step.amountIn, step.feeAmount)) + state.amountSpecifiedRemaining.Add(state.amountSpecifiedRemaining, &amountOutSigned) + state.amountCalculated.Add(state.amountCalculated, &amountInPlusFeeSigned) } // TODO @@ -357,9 +394,9 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int // if we're moving leftward, we interpret liquidityNet as the opposite sign // safe because liquidityNet cannot be type(int128).min if zeroForOne { - liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne) + liquidityNet = new(utils.Int128).Neg(liquidityNet) } - state.liquidity = utils.AddDelta(state.liquidity, liquidityNet) + utils.AddDeltaInPlace(state.liquidity, liquidityNet) crossInitTickLoops++ } @@ -371,7 +408,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int } else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96) + state.tick, err = utils.GetTickAtSqrtRatioV2(state.sqrtPriceX96) if err != nil { return nil, err } diff --git a/entities/pool_test.go b/entities/pool_test.go index 028ade7..e03b4cb 100644 --- a/entities/pool_test.go +++ b/entities/pool_test.go @@ -4,10 +4,12 @@ import ( "math/big" "testing" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" "github.com/daoleno/uniswap-sdk-core/entities" "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) @@ -15,6 +17,9 @@ var ( USDC = entities.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 6, "USDC", "USD Coin") DAI = entities.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin") OneEther = big.NewInt(1e18) + + OneEtherI256 = int256.NewInt(1e18) + OneEtherUI256 = uint256.NewInt(1e18) ) func TestNewPool(t *testing.T) { @@ -116,13 +121,13 @@ func newTestPool() *Pool { ticks := []Tick{ { Index: NearestUsableTick(utils.MinTick, constants.TickSpacings[constants.FeeLow]), - LiquidityNet: OneEther, - LiquidityGross: OneEther, + LiquidityNet: OneEtherI256, + LiquidityGross: OneEtherUI256, }, { Index: NearestUsableTick(utils.MaxTick, constants.TickSpacings[constants.FeeLow]), - LiquidityNet: new(big.Int).Mul(OneEther, constants.NegativeOne), - LiquidityGross: OneEther, + LiquidityNet: new(int256.Int).Neg(OneEtherI256), + LiquidityGross: OneEtherUI256, }, } diff --git a/entities/position.go b/entities/position.go index 20e3a49..f783519 100644 --- a/entities/position.go +++ b/entities/position.go @@ -82,7 +82,7 @@ func (p *Position) Amount0() (*entities.CurrencyAmount, error) { if err != nil { return nil, err } - p.token0Amount = entities.FromRawAmount(p.Pool.Token0, utils.GetAmount0Delta(p.Pool.SqrtRatioX96, sqrtTickUpper, p.Liquidity, true)) + p.token0Amount = entities.FromRawAmount(p.Pool.Token0, utils.GetAmount0Delta(p.Pool.SqrtRatioX96.ToBig(), sqrtTickUpper, p.Liquidity, true)) } else { p.token0Amount = entities.FromRawAmount(p.Pool.Token0, constants.Zero) } @@ -100,7 +100,7 @@ func (p *Position) Amount1() (*entities.CurrencyAmount, error) { if err != nil { return nil, err } - p.token1Amount = entities.FromRawAmount(p.Pool.Token1, utils.GetAmount1Delta(sqrtTickLower, p.Pool.SqrtRatioX96, p.Liquidity, false)) + p.token1Amount = entities.FromRawAmount(p.Pool.Token1, utils.GetAmount1Delta(sqrtTickLower, p.Pool.SqrtRatioX96.ToBig(), p.Liquidity, false)) } else { sqrtTickLower, err := utils.GetSqrtRatioAtTick(p.TickLower) if err != nil { @@ -266,8 +266,8 @@ func (p *Position) MintAmounts() (amount0, amount1 *big.Int, err error) { amount1 = constants.Zero return amount0, amount1, nil } else if p.Pool.TickCurrent < p.TickUpper { - amount0 = utils.GetAmount0Delta(p.Pool.SqrtRatioX96, rUpper, p.Liquidity, true) - amount1 = utils.GetAmount1Delta(rLower, p.Pool.SqrtRatioX96, p.Liquidity, true) + amount0 = utils.GetAmount0Delta(p.Pool.SqrtRatioX96.ToBig(), rUpper, p.Liquidity, true) + amount1 = utils.GetAmount1Delta(rLower, p.Pool.SqrtRatioX96.ToBig(), p.Liquidity, true) } else { amount0 = constants.Zero amount1 = utils.GetAmount1Delta(rLower, rUpper, p.Liquidity, true) @@ -298,7 +298,7 @@ func FromAmounts(pool *Pool, tickLower, tickUpper int, amount0, amount1 *big.Int if err != nil { return nil, err } - return NewPosition(pool, utils.MaxLiquidityForAmounts(pool.SqrtRatioX96, sqrtRatioAX96, sqrtRatioBX96, amount0, amount1, useFullPrecision), tickLower, tickUpper) + return NewPosition(pool, utils.MaxLiquidityForAmounts(pool.SqrtRatioX96.ToBig(), sqrtRatioAX96, sqrtRatioBX96, amount0, amount1, useFullPrecision), tickLower, tickUpper) } /** diff --git a/entities/tickdataprovider.go b/entities/tickdataprovider.go index 6086e61..5e5b6b0 100644 --- a/entities/tickdataprovider.go +++ b/entities/tickdataprovider.go @@ -1,11 +1,14 @@ package entities -import "math/big" +import ( + "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" + "github.com/holiman/uint256" +) type Tick struct { Index int - LiquidityGross *big.Int - LiquidityNet *big.Int + LiquidityGross *uint256.Int + LiquidityNet *utils.Int128 } // Provides information about ticks diff --git a/entities/ticklist.go b/entities/ticklist.go index 10da96c..c9d5ce7 100644 --- a/entities/ticklist.go +++ b/entities/ticklist.go @@ -3,9 +3,8 @@ package entities import ( "errors" "math" - "math/big" - "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" + "github.com/KyberNetwork/int256" ) const ( @@ -41,11 +40,11 @@ func ValidateList(ticks []Tick, tickSpacing int) error { } // ensure tick liquidity deltas sum to 0 - sum := big.NewInt(0) + sum := int256.NewInt(0) for _, tick := range ticks { sum.Add(sum, tick.LiquidityNet) } - if sum.Cmp(big.NewInt(0)) != 0 { + if !sum.IsZero() { return ErrZeroNet } @@ -195,7 +194,7 @@ func NextInitializedTickIndex(ticks []Tick, tick int, lte bool) (int, bool, erro } var isInitialized bool - if nextInitializedTick.LiquidityGross.Cmp(constants.Zero) != 0 { + if !nextInitializedTick.LiquidityGross.IsZero() { isInitialized = true } diff --git a/entities/ticklist_test.go b/entities/ticklist_test.go index 5eef4fc..ff39ca5 100644 --- a/entities/ticklist_test.go +++ b/entities/ticklist_test.go @@ -1,28 +1,29 @@ package entities import ( - "math/big" "testing" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" + "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) var ( lowTick = Tick{ Index: utils.MinTick + 1, - LiquidityNet: big.NewInt(10), - LiquidityGross: big.NewInt(10), + LiquidityNet: int256.NewInt(10), + LiquidityGross: uint256.NewInt(10), } midTick = Tick{ Index: 0, - LiquidityNet: big.NewInt(-5), - LiquidityGross: big.NewInt(5), + LiquidityNet: int256.NewInt(-5), + LiquidityGross: uint256.NewInt(5), } highTick = Tick{ Index: utils.MaxTick - 1, - LiquidityNet: big.NewInt(-5), - LiquidityGross: big.NewInt(5), + LiquidityNet: int256.NewInt(-5), + LiquidityGross: uint256.NewInt(5), } ) diff --git a/entities/trade_test.go b/entities/trade_test.go index 1d1f4d1..acb738c 100644 --- a/entities/trade_test.go +++ b/entities/trade_test.go @@ -4,10 +4,12 @@ import ( "math/big" "testing" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" "github.com/daoleno/uniswap-sdk-core/entities" "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) @@ -79,16 +81,18 @@ var ( func v2StylePool(token0, token1 *entities.Token, reserve0, reserve1 *entities.CurrencyAmount, feeAmount constants.FeeAmount) *Pool { sqrtRatioX96 := utils.EncodeSqrtRatioX96(reserve1.Quotient(), reserve0.Quotient()) liquidity := new(big.Int).Sqrt(new(big.Int).Mul(reserve0.Quotient(), reserve1.Quotient())) + liquidityGross := uint256.MustFromBig(liquidity) + liquidityNet := int256.MustFromBig(liquidity) ticks := []Tick{ { Index: NearestUsableTick(utils.MinTick, constants.TickSpacings[feeAmount]), - LiquidityNet: liquidity, - LiquidityGross: liquidity, + LiquidityNet: liquidityNet, + LiquidityGross: liquidityGross, }, { Index: NearestUsableTick(utils.MaxTick, constants.TickSpacings[feeAmount]), - LiquidityNet: new(big.Int).Mul(liquidity, big.NewInt(-1)), - LiquidityGross: liquidity, + LiquidityNet: new(int256.Int).Neg(liquidityNet), + LiquidityGross: liquidityGross, }, } s, err := utils.GetTickAtSqrtRatio(sqrtRatioX96) diff --git a/examples/helper/pool.go b/examples/helper/pool.go index be008ac..60ed931 100644 --- a/examples/helper/pool.go +++ b/examples/helper/pool.go @@ -4,7 +4,9 @@ import ( "errors" "math/big" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/examples/contract" + "github.com/holiman/uint256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" "github.com/KyberNetwork/uniswapv3-sdk-uint256/entities" @@ -61,14 +63,14 @@ func ConstructV3Pool(client *ethclient.Client, token0, token1 *coreEntities.Toke { Index: entities.NearestUsableTick(sdkutils.MinTick, constants.TickSpacings[feeAmount]), - LiquidityNet: pooltick.LiquidityNet, - LiquidityGross: pooltick.LiquidityGross, + LiquidityNet: int256.MustFromBig(pooltick.LiquidityNet), + LiquidityGross: uint256.MustFromBig(pooltick.LiquidityGross), }, { Index: entities.NearestUsableTick(sdkutils.MaxTick, constants.TickSpacings[feeAmount]), - LiquidityNet: pooltick.LiquidityNet, - LiquidityGross: pooltick.LiquidityGross, + LiquidityNet: int256.MustFromBig(pooltick.LiquidityNet), + LiquidityGross: uint256.MustFromBig(pooltick.LiquidityGross), }, } diff --git a/examples/helper/tx.go b/examples/helper/tx.go index 4d3c85c..09446b2 100644 --- a/examples/helper/tx.go +++ b/examples/helper/tx.go @@ -11,7 +11,7 @@ import ( "github.com/ethereum/go-ethereum/ethclient" ) -//SendTx Send a real transaction to the blockchain. +// SendTx Send a real transaction to the blockchain. func SendTX(client *ethclient.Client, toAddress common.Address, value *big.Int, data []byte, w *Wallet) (*types.Transaction, error) { signedTx, err := TryTX(client, toAddress, value, data, w) @@ -21,7 +21,7 @@ func SendTX(client *ethclient.Client, toAddress common.Address, value *big.Int, return signedTx, client.SendTransaction(context.Background(), signedTx) } -//Trytx Trying to send a transaction, it just return the transaction hash if success. +// Trytx Trying to send a transaction, it just return the transaction hash if success. func TryTX(client *ethclient.Client, toAddress common.Address, value *big.Int, data []byte, w *Wallet) (*types.Transaction, error) { gasPrice, err := client.SuggestGasPrice(context.Background()) diff --git a/examples/liquidity/main.go b/examples/liquidity/main.go index 8b2dd34..c6b369a 100644 --- a/examples/liquidity/main.go +++ b/examples/liquidity/main.go @@ -18,7 +18,7 @@ import ( "github.com/ethereum/go-ethereum/ethclient" ) -//mint a new liquidity +// mint a new liquidity func mintOrAdd(client *ethclient.Client, wallet *helper.Wallet, tokenID *big.Int) { log.SetFlags(log.Lshortfile | log.LstdFlags) diff --git a/go.mod b/go.mod index 76b5561..32644fb 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,14 @@ module github.com/KyberNetwork/uniswapv3-sdk-uint256 -go 1.18 +go 1.21 require ( + github.com/KyberNetwork/int256 v0.1.4 github.com/daoleno/uniswap-sdk-core v0.1.5 github.com/ethereum/go-ethereum v1.10.20 + github.com/holiman/uint256 v1.2.3 github.com/shopspring/decimal v1.3.1 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.8.4 ) require ( diff --git a/go.sum b/go.sum index d95108d..b0edde8 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,19 @@ +github.com/KyberNetwork/int256 v0.1.4 h1:SbnhxqcsZXrP+5pSkolpSo2ViEWiq3K/hjfl5OQgX/4= +github.com/KyberNetwork/int256 v0.1.4/go.mod h1:qE/Ikpo86fn60sIB7CwcfiqJTjK2p+k+ASvuO79Rq4g= github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= +github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk= +github.com/cespare/cp v0.1.0/go.mod h1:SOGHArjBr4JWaSDEVpWpo/hNg6RoKrls6Oh40hiwW+s= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/daoleno/uniswap-sdk-core v0.1.5 h1:VlU6NXnJBJ75D3GmX01CGIEMoiizXlu9v+jSEj26lhM= github.com/daoleno/uniswap-sdk-core v0.1.5/go.mod h1:OV1Kvws5JShxPz3qFpjpkuZB4gdebRpqm/AcYMZ7TZQ= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/deckarep/golang-set v1.8.0 h1:sk9/l/KqpunDwP7pSjUg0keiOOLEnOBHzykLrsPppp4= @@ -17,64 +23,88 @@ github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= github.com/ethereum/go-ethereum v1.10.20 h1:75IW830ClSS40yrQC1ZCMZCt5I+zU16oqId2SiQwdQ4= github.com/ethereum/go-ethereum v1.10.20/go.mod h1:LWUN82TCHGpxB3En5HVmLLzPD7YSrEUFmFfN1nKkVN0= github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5 h1:FtmdgXiUlNeRsoNMFlKLDt+S+6hbjVMEW6RGQ7aUf7c= +github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5/go.mod h1:VvhXpOYNQvB+uIk2RvXzuaQtkQJzzIx6lSBe1xv7hi0= github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff h1:tY80oXqGNY4FhTFhk+o9oFHGINQ/+vhlm8HFzi6znCI= +github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff/go.mod h1:x7DCsMOv1taUwEWCzT4cmDeAkigA5/QCwUodaVOe8Ww= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-bexpr v0.1.10 h1:9kuI5PFotCboP3dkDYFr/wi0gg0QVbSNz5oFRpxn4uE= +github.com/hashicorp/go-bexpr v0.1.10/go.mod h1:oxlubA2vC/gFVfX1A6JGp7ls7uCDlfJn732ehYYg+g0= github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d h1:dg1dEPuWpEqDnvIw251EVy4zlP8gWbsGj4BsUKCRpYs= +github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= -github.com/holiman/uint256 v1.2.0 h1:gpSYcPLWGv4sG43I2mVLiDZCNDh/EpGjSk8tmtxitHM= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= +github.com/holiman/uint256 v1.2.3 h1:K8UWO1HUJpRMXBxbmaY1Y8IAMZC/RsKB+ArEnnK4l5o= +github.com/holiman/uint256 v1.2.3/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ= +github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/pointerstructure v1.2.0 h1:O+i9nHnXS3l/9Wu7r4NrEdwA2VFTicjUEN1uBnDo34A= +github.com/mitchellh/pointerstructure v1.2.0/go.mod h1:BRAsLI5zgXmw97Lf6s25bs8ohIXc3tViBH44KcwB2g4= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/tsdb v0.7.1 h1:YZcsG11NqnK4czYLrWd9mpEuAJIHVQLwdrleYfszMAA= +github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/rjeczalik/notify v0.9.2 h1:MiTWrPj55mNDHEiIX5YUSKefw/+lCQVoAFmD6oQm5w8= github.com/rjeczalik/notify v0.9.2/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= +github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4 h1:Gb2Tyox57NRNuZ2d3rmvB3pcmbu7O1RS3m8WRx7ilrg= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4/go.mod h1:RZLeN1LMWmRsyYjvAu+I6Dm9QmlDaIIt+Y+4Kd7Tp+Q= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a h1:1ur3QoCqvE5fl+nylMaIr9PVV1w343YRDtsy+Rwu7XI= +github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a/go.mod h1:RRCYJbIwD5jmqPI9XoAFR0OcDxqUctll6zUj/+B4S48= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= github.com/tklauser/numcpus v0.4.0/go.mod h1:1+UI3pD8NW14VMwdgJNJ1ESk2UnwhAnz5hMwiKKqXCQ= github.com/tklauser/numcpus v0.5.0 h1:ooe7gN0fg6myJ0EKoTAf5hebTZrH52px3New/D9iJ+A= github.com/tklauser/numcpus v0.5.0/go.mod h1:OGzpTxpcIMNGYQdit2BYL1pvk/dSOaJWjKoflh+RQjo= github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef h1:wHSqTBrZW24CsNJDfeh9Ex6Pm0Rcpc7qrgKBiL44vF4= +github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef/go.mod h1:sJ5fKU0s6JVwZjjcUEX2zFOnvq0ASQ2K9Zr6cf67kNs= github.com/urfave/cli/v2 v2.10.2 h1:x3p8awjp/2arX+Nl/G2040AZpOCHS/eMJJ1/a+mye4Y= +github.com/urfave/cli/v2 v2.10.2/go.mod h1:f8iq5LtQ/bLxafbdBSLPPNsgaW0l/2fYYEHhAyPlwvo= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -82,11 +112,12 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220702020025-31831981b65f h1:xdsejrW/0Wf2diT5CPp3XmKUNbr7Xvw8kYilQ+6qjRY= golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/periphery/const_test.go b/periphery/const_test.go index 70ea531..c53b82b 100644 --- a/periphery/const_test.go +++ b/periphery/const_test.go @@ -3,11 +3,13 @@ package periphery import ( "math/big" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" "github.com/KyberNetwork/uniswapv3-sdk-uint256/entities" "github.com/KyberNetwork/uniswapv3-sdk-uint256/utils" core "github.com/daoleno/uniswap-sdk-core/entities" "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" ) var ( @@ -32,6 +34,10 @@ var ( route_weth_0, _ = entities.NewRoute([]*entities.Pool{pool_0_weth}, weth, token0) route_weth_0_1, _ = entities.NewRoute([]*entities.Pool{pool_0_weth, pool_0_1_medium}, weth, token1) + liquidityGross = uint256.NewInt(1_000_000) + liquidityNet = int256.NewInt(1_000_000) + liquidityNetNeg = int256.NewInt(-1_000_000) + feeAmount = constants.FeeMedium sqrtRatioX96 = utils.EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) liquidity = big.NewInt(1_000_000) @@ -39,13 +45,13 @@ var ( ticks = []entities.Tick{ { Index: entities.NearestUsableTick(utils.MinTick, constants.TickSpacings[feeAmount]), - LiquidityNet: liquidity, - LiquidityGross: liquidity, + LiquidityNet: liquidityNet, + LiquidityGross: liquidityGross, }, { Index: entities.NearestUsableTick(utils.MaxTick, constants.TickSpacings[feeAmount]), - LiquidityNet: new(big.Int).Mul(liquidity, constants.NegativeOne), - LiquidityGross: liquidity, + LiquidityNet: liquidityNetNeg, + LiquidityGross: liquidityGross, }, } diff --git a/periphery/nonfungible_position_manager.go b/periphery/nonfungible_position_manager.go index 6cdbd20..54c0b5a 100644 --- a/periphery/nonfungible_position_manager.go +++ b/periphery/nonfungible_position_manager.go @@ -43,7 +43,7 @@ type IncreaseSpecificOptions struct { TokenID *big.Int // Indicates the ID of the position to increase liquidity for } -// Options for producing the calldata to add liquidity +// Options for producing the calldata to add liquidity type CommonAddLiquidityOptions struct { SlippageTolerance *core.Percent // How much the pool price is allowed to move Deadline *big.Int // When the transaction expires, in epoch seconds @@ -143,7 +143,7 @@ type DecreaseLiquidityParams struct { func encodeCreate(pool *entities.Pool) ([]byte, error) { abi := getNonFungiblePositionManagerABI() - return abi.Pack("createAndInitializePoolIfNecessary", pool.Token0.Address, pool.Token1.Address, big.NewInt(int64(pool.Fee)), pool.SqrtRatioX96) + return abi.Pack("createAndInitializePoolIfNecessary", pool.Token0.Address, pool.Token1.Address, big.NewInt(int64(pool.Fee)), pool.SqrtRatioX96.ToBig()) } func CreateCallParameters(pool *entities.Pool) (*utils.MethodParameters, error) { diff --git a/utils/full_math.go b/utils/full_math.go index 34f03ba..8d4584b 100644 --- a/utils/full_math.go +++ b/utils/full_math.go @@ -1,16 +1,56 @@ package utils import ( + "errors" "math/big" - "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" + "github.com/holiman/uint256" ) -func MulDivRoundingUp(a, b, denominator *big.Int) *big.Int { - product := new(big.Int).Mul(a, b) - result := new(big.Int).Div(product, denominator) - if new(big.Int).Rem(product, denominator).Cmp(big.NewInt(0)) != 0 { - result.Add(result, constants.One) +var ( + ErrMulDivOverflow = errors.New("muldiv overflow") + One = big.NewInt(1) +) + +// Calculates ceil(a×b÷denominator) with full precision +func MulDivRoundingUp(a, b, denominator *uint256.Int) (*uint256.Int, error) { + // the product can overflow so need to use big.Int here + // TODO: optimize this + var product, rem, result big.Int + product.Mul(a.ToBig(), b.ToBig()) + result.DivMod(&product, denominator.ToBig(), &rem) + if rem.Sign() != 0 { + result.Add(&result, One) + } + + resultU, overflow := uint256.FromBig(&result) + if overflow { + return nil, ErrMulDivOverflow + } + return resultU, nil +} + +// Calculates floor(a×b÷denominator) with full precision +func MulDiv(a, b, denominator *uint256.Int) (*uint256.Int, error) { + // the product can overflow so need to use big.Int here + // TODO: optimize this follow univ3 code + var product, result big.Int + product.Mul(a.ToBig(), b.ToBig()) + result.Div(&product, denominator.ToBig()) + + resultU, overflow := uint256.FromBig(&result) + if overflow { + return nil, ErrMulDivOverflow + } + return resultU, nil +} + +// Returns ceil(x / y) +func DivRoundingUp(a, denominator *uint256.Int) *uint256.Int { + var result, rem uint256.Int + result.DivMod(a, denominator, &rem) + if !rem.IsZero() { + result.AddUint64(&result, 1) } - return result + return &result } diff --git a/utils/full_math_test.go b/utils/full_math_test.go new file mode 100644 index 0000000..fa694ed --- /dev/null +++ b/utils/full_math_test.go @@ -0,0 +1,100 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMulDiv(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/FullMath.spec.ts + + tests := []struct { + a string + b string + deno string + expResult string + }{ + {MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()}, + {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070485"}, + {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, + {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070485"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := MulDiv( + uint256.MustFromHex(tt.a), uint256.MustFromHex(tt.b), + uint256.MustFromHex(tt.deno)) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } + + failTests := []struct { + a string + b string + deno string + }{ + // {"0x100000000000000000000000000000000", "0x5", "0x0"}, // we don't catch div by zero here + // {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x0"}, + {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x1"}, + {MaxUint256.Hex(), MaxUint256.Hex(), new(Uint256).SubUint64(MaxUint256, 1).Hex()}, + } + for i, tt := range failTests { + t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) { + _, err := MulDiv( + uint256.MustFromHex(tt.a), uint256.MustFromHex(tt.b), + uint256.MustFromHex(tt.deno)) + require.NotNil(t, err) + }) + } +} + +func TestMulDivRoundingUp(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/FullMath.spec.ts + + tests := []struct { + a string + b string + deno string + expResult string + }{ + {MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Hex(), MaxUint256.Dec()}, + {"0x100000000000000000000000000000000", "0x80000000000000000000000000000000", "0x180000000000000000000000000000000", "113427455640312821154458202477256070486"}, + {"0x100000000000000000000000000000000", "0x2300000000000000000000000000000000", "0x800000000000000000000000000000000", "1488735355279105777652263907513985925120"}, + {"0x100000000000000000000000000000000", "0x3e800000000000000000000000000000000", "0xbb800000000000000000000000000000000", "113427455640312821154458202477256070486"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := MulDivRoundingUp( + uint256.MustFromHex(tt.a), uint256.MustFromHex(tt.b), + uint256.MustFromHex(tt.deno)) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } + + failTests := []struct { + a string + b string + deno string + }{ + // {"0x100000000000000000000000000000000", "0x5", "0x0"}, // we don't catch div by zero here + // {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x0"}, + {"0x100000000000000000000000000000000", "0x100000000000000000000000000000000", "0x1"}, + {MaxUint256.Hex(), MaxUint256.Hex(), new(Uint256).SubUint64(MaxUint256, 1).Hex()}, + {"0x1e695d2db4f97", "0x10d5effea103c44aaf18a26b449186a7de3dd6c1ce3d26d03dfd9", "0x2"}, // mulDiv overflows 256 bits after rounding up + {"0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b35", "0xffffffffffffffffffffffffffffffffffffffb07f6d608e4dcc38020b140b36", "0xffffffffffffffffffffffffffffffffffffff60fedac11c9b9870041628166c"}, // mulDiv overflows 256 bits after rounding up case 2 + } + for i, tt := range failTests { + t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) { + x, err := MulDivRoundingUp( + uint256.MustFromHex(tt.a), uint256.MustFromHex(tt.b), + uint256.MustFromHex(tt.deno)) + require.NotNil(t, err, x) + }) + } +} diff --git a/utils/int_types.go b/utils/int_types.go new file mode 100644 index 0000000..f790ea8 --- /dev/null +++ b/utils/int_types.go @@ -0,0 +1,66 @@ +package utils + +import ( + "errors" + + "github.com/KyberNetwork/int256" + "github.com/holiman/uint256" +) + +// define placeholders for these types, in case we need to customize them later +// (for example to add boundary check...) + +type Uint256 = uint256.Int +type Uint160 = uint256.Int +type Uint128 = uint256.Int + +type Int256 = int256.Int +type Int128 = int256.Int + +var ( + ErrExceedMaxInt256 = errors.New("exceed max int256") + ErrOverflowUint128 = errors.New("overflow uint128") + ErrOverflowUint160 = errors.New("overflow uint160") + + Uint128Max = uint256.MustFromHex("0xffffffffffffffffffffffffffffffff") + Uint160Max = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffff") +) + +// https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/SafeCast.sol +func ToInt256(value *Uint256, result *Int256) error { + // if value (interpreted as a two's complement signed number) is negative -> it must be larger than max int256 + if value.Sign() < 0 { + return ErrExceedMaxInt256 + } + var ba [32]byte + value.WriteToArray32(&ba) + result.SetBytes32(ba[:]) + return nil +} + +// https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/SafeCast.sol +func CheckToUint160(value *Uint256) error { + // we're using same type for Uint256 and Uint160, so use the original for now + if value.Cmp(Uint160Max) > 0 { + return ErrOverflowUint160 + } + return nil +} + +// x = x + y +func AddDeltaInPlace(x *Uint128, y *Int128) error { + // for now we're using int256 for Int128, and uint256 for Uint128 + // and both of them is using two's complement internally + // so just cast `y` to uint256 and add them together + var ba [32]byte + y.WriteToArray32(&ba) + var yuint Uint128 + yuint.SetBytes32(ba[:]) + x.Add(x, &yuint) + + if x.Cmp(Uint128Max) > 0 { + // could be overflow or underflow + return ErrOverflowUint128 + } + return nil +} diff --git a/utils/int_types_test.go b/utils/int_types_test.go new file mode 100644 index 0000000..25a41ce --- /dev/null +++ b/utils/int_types_test.go @@ -0,0 +1,74 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/KyberNetwork/int256" + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToInt256(t *testing.T) { + // https://github.com/OpenZeppelin/openzeppelin-contracts/blob/692dbc560f48b2a5160e6e4f78302bb93314cd88/test/utils/math/SafeCast.test.js#L124 + + successCases := []string{ + "0x0", + "0x1", + "0x18fe", + "0x9234bbe", + "0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // INT256_MAX + } + + var res int256.Int + for _, tc := range successCases { + t.Run(fmt.Sprintf("test %s", tc), func(t *testing.T) { + ui := uint256.MustFromHex(tc) + err := ToInt256(ui, &res) + require.Nil(t, err) + + // should be equal to the original value + assert.Equal(t, ui.Dec(), res.Dec()) + }) + } + + // INT256_MAX+1 + assert.ErrorIs(t, ErrExceedMaxInt256, ToInt256(uint256.MustFromHex("0x8000000000000000000000000000000000000000000000000000000000000000"), &res)) + // INT256_MAX+2 + assert.ErrorIs(t, ErrExceedMaxInt256, ToInt256(uint256.MustFromHex("0x8000000000000000000000000000000000000000000000000000000000000001"), &res)) + // UINT256_MAX + assert.ErrorIs(t, ErrExceedMaxInt256, ToInt256(uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), &res)) +} + +func TestAddDeltaInPlace(t *testing.T) { + //https://github.com/Uniswap/v3-core/blob/main/test/LiquidityMath.spec.ts + + successCases := []struct { + x *Uint128 + y *Int128 + expX *Uint128 + }{ + {uint256.NewInt(1), int256.NewInt(0), uint256.NewInt(1)}, + {uint256.NewInt(1), int256.NewInt(-1), uint256.NewInt(0)}, + {uint256.NewInt(1), int256.NewInt(1), uint256.NewInt(2)}, + } + + for i, tc := range successCases { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + err := AddDeltaInPlace(tc.x, tc.y) + require.Nil(t, err) + + // should be equal to the original value + assert.Equal(t, tc.expX.Dec(), tc.x.Dec()) + }) + } + + // 2**128-15 + 15 overflows + tmp := new(uint256.Int).SubUint64(new(uint256.Int).Exp(uint256.NewInt(2), uint256.NewInt(128)), 15) + assert.ErrorIs(t, ErrOverflowUint128, AddDeltaInPlace(tmp, int256.NewInt(15))) + // 0 + -1 underflows + assert.ErrorIs(t, ErrOverflowUint128, AddDeltaInPlace(uint256.NewInt(0), int256.NewInt(-1))) + // 3 + -4 underflows underflows + assert.ErrorIs(t, ErrOverflowUint128, AddDeltaInPlace(uint256.NewInt(3), int256.NewInt(-4))) +} diff --git a/utils/liquidity_math.go b/utils/liquidity_math.go deleted file mode 100644 index 631b4d7..0000000 --- a/utils/liquidity_math.go +++ /dev/null @@ -1,15 +0,0 @@ -package utils - -import ( - "math/big" - - "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" -) - -func AddDelta(x, y *big.Int) *big.Int { - if y.Cmp(constants.Zero) < 0 { - return new(big.Int).Sub(x, new(big.Int).Mul(y, constants.NegativeOne)) - } else { - return new(big.Int).Add(x, y) - } -} diff --git a/utils/most_significant_bit.go b/utils/most_significant_bit.go index 42ae01f..e7d17ff 100644 --- a/utils/most_significant_bit.go +++ b/utils/most_significant_bit.go @@ -2,27 +2,40 @@ package utils import ( "errors" - "math/big" - "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" - "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/holiman/uint256" ) var ErrInvalidInput = errors.New("invalid input") -func MostSignificantBit(x *big.Int) (int64, error) { - if x.Cmp(constants.Zero) <= 0 { - return 0, ErrInvalidInput - } - if x.Cmp(entities.MaxUint256) > 0 { +type powerOf2 struct { + power uint + value *uint256.Int +} + +var powersOf2 = []powerOf2{ + {128, uint256.MustFromHex("0x100000000000000000000000000000000")}, + {64, uint256.MustFromHex("0x10000000000000000")}, + {32, uint256.MustFromHex("0x100000000")}, + {16, uint256.MustFromHex("0x10000")}, + {8, uint256.MustFromHex("0x100")}, + {4, uint256.MustFromHex("0x10")}, + {2, uint256.MustFromHex("0x4")}, + {1, uint256.MustFromHex("0x2")}, +} + +func MostSignificantBit(x *uint256.Int) (uint, error) { + if x.Sign() == 0 { return 0, ErrInvalidInput } - var msb int64 - for _, power := range []int64{128, 64, 32, 16, 8, 4, 2, 1} { - min := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(power)), nil) - if x.Cmp(min) >= 0 { - x = new(big.Int).Rsh(x, uint(power)) - msb += power + + var tmpX uint256.Int + tmpX.Set(x) + var msb uint + for _, p := range powersOf2 { + if tmpX.Cmp(p.value) >= 0 { + tmpX.Rsh(&tmpX, p.power) + msb += p.power } } return msb, nil diff --git a/utils/most_significant_bit_test.go b/utils/most_significant_bit_test.go new file mode 100644 index 0000000..85435d1 --- /dev/null +++ b/utils/most_significant_bit_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMostSignificantBit(t *testing.T) { + tests := []struct { + value string + expResult uint + }{ + {"0x1", 0}, + {"0x100000000000000000000000000000000", 128}, + {"0x10000000000000000", 64}, + {"0x100000000", 32}, + {"0x10000", 16}, + {"0x100", 8}, + {"0x10", 4}, + {"0x4", 2}, + {"0x2", 1}, + {"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 255}, // 2^256 - 1 + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := MostSignificantBit(uint256.MustFromHex(tt.value)) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r) + }) + } +} diff --git a/utils/sqrtprice_math.go b/utils/sqrtprice_math.go index aca664d..c8b6c6d 100644 --- a/utils/sqrtprice_math.go +++ b/utils/sqrtprice_math.go @@ -5,58 +5,102 @@ import ( "math/big" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" - "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/holiman/uint256" ) var ( ErrSqrtPriceLessThanZero = errors.New("sqrt price less than zero") ErrLiquidityLessThanZero = errors.New("liquidity less than zero") ErrInvariant = errors.New("invariant violation") + ErrAddOverflow = errors.New("add overflow") ) -var MaxUint160 = new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(160), nil), constants.One) -func multiplyIn256(x, y *big.Int) *big.Int { - product := new(big.Int).Mul(x, y) - return new(big.Int).And(product, entities.MaxUint256) +var MaxUint160 = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffff") + +func multiplyIn256(x, y, product *uint256.Int) *uint256.Int { + product.Mul(x, y) + return product // no need to And with MaxUint256 here } -func addIn256(x, y *big.Int) *big.Int { - sum := new(big.Int).Add(x, y) - return new(big.Int).And(sum, entities.MaxUint256) +func addIn256(x, y, sum *uint256.Int) *uint256.Int { + sum.Add(x, y) + return sum // no need to And with MaxUint256 here } +// deprecated func GetAmount0Delta(sqrtRatioAX96, sqrtRatioBX96, liquidity *big.Int, roundUp bool) *big.Int { + res, err := GetAmount0DeltaV2( + uint256.MustFromBig(sqrtRatioAX96), + uint256.MustFromBig(sqrtRatioBX96), + uint256.MustFromBig(liquidity), + roundUp, + ) + if err != nil { + panic(err) + } + return res.ToBig() +} + +func GetAmount0DeltaV2(sqrtRatioAX96, sqrtRatioBX96 *Uint160, liquidity *Uint128, roundUp bool) (*uint256.Int, error) { // https://github.com/Uniswap/v3-core/blob/d8b1c635c275d2a9450bd6a78f3fa2484fef73eb/contracts/libraries/SqrtPriceMath.sol#L159 if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 } - numerator1 := new(big.Int).Lsh(liquidity, 96) - numerator2 := new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96) + var numerator1, numerator2 uint256.Int + numerator1.Lsh(liquidity, 96) + numerator2.Sub(sqrtRatioBX96, sqrtRatioAX96) if roundUp { - return MulDivRoundingUp(MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96), constants.One, sqrtRatioAX96) + deno, err := MulDivRoundingUp(&numerator1, &numerator2, sqrtRatioBX96) + if err != nil { + return nil, err + } + return DivRoundingUp(deno, sqrtRatioAX96), nil + } + // : FullMath.mulDiv(numerator1, numerator2, sqrtRatioBX96) / sqrtRatioAX96; + tmp, err := MulDiv(&numerator1, &numerator2, sqrtRatioBX96) + if err != nil { + return nil, err } - return new(big.Int).Div(new(big.Int).Div(new(big.Int).Mul(numerator1, numerator2), sqrtRatioBX96), sqrtRatioAX96) + result := new(uint256.Int).Div(tmp, sqrtRatioAX96) + return result, nil } +// deprecated func GetAmount1Delta(sqrtRatioAX96, sqrtRatioBX96, liquidity *big.Int, roundUp bool) *big.Int { + res, err := GetAmount1DeltaV2( + uint256.MustFromBig(sqrtRatioAX96), + uint256.MustFromBig(sqrtRatioBX96), + uint256.MustFromBig(liquidity), + roundUp, + ) + if err != nil { + panic(err) + } + return res.ToBig() +} + +func GetAmount1DeltaV2(sqrtRatioAX96, sqrtRatioBX96 *Uint160, liquidity *Uint128, roundUp bool) (*uint256.Int, error) { // https://github.com/Uniswap/v3-core/blob/d8b1c635c275d2a9450bd6a78f3fa2484fef73eb/contracts/libraries/SqrtPriceMath.sol#L188 if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 } + var diff uint256.Int + diff.Sub(sqrtRatioBX96, sqrtRatioAX96) if roundUp { - return MulDivRoundingUp(liquidity, new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96), constants.Q96) + return MulDivRoundingUp(liquidity, &diff, constants.Q96U256) } - return new(big.Int).Div(new(big.Int).Mul(liquidity, new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96)), constants.Q96) + // : FullMath.mulDiv(liquidity, sqrtRatioBX96 - sqrtRatioAX96, FixedPoint96.Q96); + return MulDiv(liquidity, &diff, constants.Q96U256) } -func GetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn *big.Int, zeroForOne bool) (*big.Int, error) { - if sqrtPX96.Cmp(constants.Zero) <= 0 { +func GetNextSqrtPriceFromInput(sqrtPX96 *Uint160, liquidity *Uint128, amountIn *uint256.Int, zeroForOne bool) (*Uint160, error) { + if sqrtPX96.Sign() <= 0 { return nil, ErrSqrtPriceLessThanZero } - if liquidity.Cmp(constants.Zero) <= 0 { + if liquidity.Sign() <= 0 { return nil, ErrLiquidityLessThanZero } if zeroForOne { @@ -65,11 +109,11 @@ func GetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn *big.Int, zeroForOn return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) } -func GetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut *big.Int, zeroForOne bool) (*big.Int, error) { - if sqrtPX96.Cmp(constants.Zero) <= 0 { +func GetNextSqrtPriceFromOutput(sqrtPX96 *Uint160, liquidity *Uint128, amountOut *uint256.Int, zeroForOne bool) (*Uint160, error) { + if sqrtPX96.Sign() <= 0 { return nil, ErrSqrtPriceLessThanZero } - if liquidity.Cmp(constants.Zero) <= 0 { + if liquidity.Sign() <= 0 { return nil, ErrLiquidityLessThanZero } if zeroForOne { @@ -78,48 +122,65 @@ func GetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut *big.Int, zeroFor return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) } -func getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amount *big.Int, add bool) (*big.Int, error) { - if amount.Cmp(constants.Zero) == 0 { +func getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96 *Uint160, liquidity *Uint128, amount *uint256.Int, add bool) (*Uint160, error) { + if amount.IsZero() { return sqrtPX96, nil } - numerator1 := new(big.Int).Lsh(liquidity, 96) + var numerator1, denominator, product, tmp uint256.Int + numerator1.Lsh(liquidity, 96) + multiplyIn256(amount, sqrtPX96, &product) if add { - product := multiplyIn256(amount, sqrtPX96) - if new(big.Int).Div(product, amount).Cmp(sqrtPX96) == 0 { - denominator := addIn256(numerator1, product) - if denominator.Cmp(numerator1) >= 0 { - return MulDivRoundingUp(numerator1, sqrtPX96, denominator), nil + if tmp.Div(&product, amount).Cmp(sqrtPX96) == 0 { + addIn256(&numerator1, &product, &denominator) + if denominator.Cmp(&numerator1) >= 0 { + return MulDivRoundingUp(&numerator1, sqrtPX96, &denominator) } } - return MulDivRoundingUp(numerator1, constants.One, new(big.Int).Add(new(big.Int).Div(numerator1, sqrtPX96), amount)), nil + tmp.Div(&numerator1, sqrtPX96) + tmp.Add(&tmp, amount) + return DivRoundingUp(&numerator1, &tmp), nil } else { - product := multiplyIn256(amount, sqrtPX96) - if new(big.Int).Div(product, amount).Cmp(sqrtPX96) != 0 { + if tmp.Div(&product, amount).Cmp(sqrtPX96) != 0 { return nil, ErrInvariant } - if numerator1.Cmp(product) <= 0 { + if numerator1.Cmp(&product) <= 0 { return nil, ErrInvariant } - denominator := new(big.Int).Sub(numerator1, product) - return MulDivRoundingUp(numerator1, sqrtPX96, denominator), nil + denominator.Sub(&numerator1, &product) + return MulDivRoundingUp(&numerator1, sqrtPX96, &denominator) } } -func getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amount *big.Int, add bool) (*big.Int, error) { +func getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96 *Uint160, liquidity *Uint128, amount *uint256.Int, add bool) (*Uint160, error) { if add { - var quotient *big.Int + var quotient, tmp uint256.Int if amount.Cmp(MaxUint160) <= 0 { - quotient = new(big.Int).Div(new(big.Int).Lsh(amount, 96), liquidity) + tmp.Lsh(amount, 96) + quotient.Div(&tmp, liquidity) } else { - quotient = new(big.Int).Div(new(big.Int).Mul(amount, constants.Q96), liquidity) + tmp.Mul(amount, constants.Q96U256) + quotient.Div(&tmp, liquidity) } - return new(big.Int).Add(sqrtPX96, quotient), nil + _, overflow := quotient.AddOverflow("ient, sqrtPX96) + if overflow { + return nil, ErrAddOverflow + } + err := CheckToUint160("ient) + if err != nil { + return nil, err + } + return "ient, nil } - quotient := MulDivRoundingUp(amount, constants.Q96, liquidity) + quotient, err := MulDivRoundingUp(amount, constants.Q96U256, liquidity) + if err != nil { + return nil, err + } if sqrtPX96.Cmp(quotient) <= 0 { return nil, ErrInvariant } - return new(big.Int).Sub(sqrtPX96, quotient), nil + quotient.Sub(sqrtPX96, quotient) + // always fits 160 bits + return quotient, nil } diff --git a/utils/sqrtprice_math_test.go b/utils/sqrtprice_math_test.go new file mode 100644 index 0000000..843e863 --- /dev/null +++ b/utils/sqrtprice_math_test.go @@ -0,0 +1,199 @@ +package utils + +import ( + "fmt" + "math/big" + "testing" + + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetNextSqrtPriceFromInput(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/SqrtPriceMath.spec.ts + p1 := EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) + + tests := []struct { + price string + liquidity string + amount string + zeroForOne bool + expResult string + }{ + {"0x1", "0x1", "0x8000000000000000000000000000000000000000000000000000000000000000", true, "1"}, + {"0x" + p1.Text(16), "0x16345785d8a0000", "0x0", true, p1.Text(10)}, + {"0x" + p1.Text(16), "0x16345785d8a0000", "0x0", false, p1.Text(10)}, + {"0xffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffff", "0xfffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffff", true, "1"}, + {"0x" + p1.Text(16), "0xde0b6b3a7640000", "0x16345785d8a0000", false, "87150978765690771352898345369"}, + {"0x" + p1.Text(16), "0xde0b6b3a7640000", "0x16345785d8a0000", true, "72025602285694852357767227579"}, + {"0x" + p1.Text(16), "0x8ac7230489e80000", "0x10000000000000000000000000", true, "624999999995069620"}, + {"0x" + p1.Text(16), "0x1", "0x8000000000000000000000000000000000000000000000000000000000000000", true, "1"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := GetNextSqrtPriceFromInput( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } + + failTests := []struct { + price string + liquidity string + amount string + zeroForOne bool + }{ + {"0x0", "0x1", "0x16345785d8a0000", false}, + {"0x1", "0x0", "0x16345785d8a0000", true}, + } + for i, tt := range failTests { + t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) { + _, err := GetNextSqrtPriceFromInput( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.NotNil(t, err) + }) + } +} + +func TestGetNextSqrtPriceFromOutput(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/SqrtPriceMath.spec.ts + p1 := EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) + + tests := []struct { + price string + liquidity string + amount string + zeroForOne bool + expResult string + }{ + {"0x100000000000000000000000000", "0x400", "0x3ffff", true, "77371252455336267181195264"}, + {"0x" + p1.Text(16), "0x16345785d8a0000", "0x0", true, p1.Text(10)}, + {"0x" + p1.Text(16), "0x16345785d8a0000", "0x0", false, p1.Text(10)}, + {"0x" + p1.Text(16), "0xde0b6b3a7640000", "0x16345785d8a0000", false, "88031291682515930659493278152"}, + {"0x" + p1.Text(16), "0xde0b6b3a7640000", "0x16345785d8a0000", true, "71305346262837903834189555302"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := GetNextSqrtPriceFromOutput( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } + + failTests := []struct { + price string + liquidity string + amount string + zeroForOne bool + }{ + {"0x0", "0x1", "0x16345785d8a0000", false}, + {"0x1", "0x0", "0x16345785d8a0000", true}, + {"0x100000000000000000000000000", "0x400", "0x4", false}, // output amount is exactly the virtual reserves of token0 + {"0x100000000000000000000000000", "0x400", "0x5", false}, // output amount is greater than virtual reserves of token0 + {"0x100000000000000000000000000", "0x400", "0x40001", true}, // output amount is greater than virtual reserves of token1 + {"0x100000000000000000000000000", "0x400", "0x40000", true}, // output amount is exactly the virtual reserves of token1 + + {"0x" + p1.Text(16), "0x1", MaxUint256.Hex(), true}, // amountOut is impossible in zero for one direction + {"0x" + p1.Text(16), "0x1", MaxUint256.Hex(), false}, // amountOut is impossible in one for zero direction + } + for i, tt := range failTests { + t.Run(fmt.Sprintf("fail test %d", i), func(t *testing.T) { + _, err := GetNextSqrtPriceFromOutput( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.NotNil(t, err) + }) + } +} + +func TestGetAmount0Delta(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/SqrtPriceMath.spec.ts + p1 := EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) + p2 := EncodeSqrtRatioX96(big.NewInt(2), big.NewInt(1)) + p3 := EncodeSqrtRatioX96(big.NewInt(121), big.NewInt(100)) + + p4 := EncodeSqrtRatioX96(new(big.Int).Exp(big.NewInt(2), big.NewInt(90), nil), big.NewInt(1)) + p5 := EncodeSqrtRatioX96(new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil), big.NewInt(1)) + + tests := []struct { + price string + liquidity string + amount string + zeroForOne bool + expResult string + }{ + {"0x" + p1.Text(16), "0x" + p2.Text(16), "0x0", true, "0"}, + {"0x" + p1.Text(16), "0x" + p1.Text(16), "0x1", true, "0"}, + {"0x" + p1.Text(16), "0x" + p3.Text(16), "0xde0b6b3a7640000", true, "90909090909090910"}, + {"0x" + p1.Text(16), "0x" + p3.Text(16), "0xde0b6b3a7640000", false, "90909090909090909"}, + {"0x" + p4.Text(16), "0x" + p5.Text(16), "0xde0b6b3a7640000", true, "24869"}, + {"0x" + p4.Text(16), "0x" + p5.Text(16), "0xde0b6b3a7640000", false, "24868"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := GetAmount0DeltaV2( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } +} + +func TestGetAmount1Delta(t *testing.T) { + // https://github.com/Uniswap/v3-core/blob/main/test/SqrtPriceMath.spec.ts + p1 := EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) + p2 := EncodeSqrtRatioX96(big.NewInt(2), big.NewInt(1)) + p3 := EncodeSqrtRatioX96(big.NewInt(121), big.NewInt(100)) + p4 := EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(121)) + + tests := []struct { + price string + liquidity string + amount string + zeroForOne bool + expResult string + }{ + {"0x" + p1.Text(16), "0x" + p2.Text(16), "0x0", true, "0"}, + {"0x" + p1.Text(16), "0x" + p1.Text(16), "0x1", true, "0"}, + {"0x" + p1.Text(16), "0x" + p3.Text(16), "0xde0b6b3a7640000", true, "100000000000000000"}, + {"0x" + p1.Text(16), "0x" + p3.Text(16), "0xde0b6b3a7640000", false, "99999999999999999"}, + {"0x" + p4.Text(16), "0x" + p1.Text(16), "0xde0b6b3a7640000", true, "90909090909090910"}, + {"0x" + p4.Text(16), "0x" + p1.Text(16), "0xde0b6b3a7640000", false, "90909090909090909"}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + r, err := GetAmount1DeltaV2( + uint256.MustFromHex(tt.price), uint256.MustFromHex(tt.liquidity), + uint256.MustFromHex(tt.amount), tt.zeroForOne) + require.Nil(t, err) + assert.Equal(t, tt.expResult, r.Dec()) + }) + } +} + +func TestSwap(t *testing.T) { + // sqrtP * sqrtQ overflows + + sqrtQ, err := GetNextSqrtPriceFromInput( + uint256.MustFromDecimal("1025574284609383690408304870162715216695788925244"), + uint256.MustFromDecimal("50015962439936049619261659728067971248"), + uint256.MustFromDecimal("406"), true) + require.Nil(t, err) + + require.Equal(t, "1025574284609383582644711336373707553698163132913", sqrtQ.Dec()) + + amount0Delta, err := GetAmount0DeltaV2( + sqrtQ, + uint256.MustFromDecimal("1025574284609383690408304870162715216695788925244"), + uint256.MustFromDecimal("50015962439936049619261659728067971248"), true) + require.Nil(t, err) + + assert.Equal(t, "406", amount0Delta.Dec()) +} diff --git a/utils/swap_math.go b/utils/swap_math.go index 071c74b..51568af 100644 --- a/utils/swap_math.go +++ b/utils/swap_math.go @@ -3,40 +3,79 @@ package utils import ( "math/big" + "github.com/KyberNetwork/int256" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" + "github.com/holiman/uint256" ) var MaxFee = new(big.Int).Exp(big.NewInt(10), big.NewInt(6), nil) -func ComputeSwapStep(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, amountRemaining *big.Int, feePips constants.FeeAmount) (sqrtRatioNextX96, amountIn, amountOut, feeAmount *big.Int, err error) { +const MaxFeeInt = 1000000 + +var MaxFeeUint256 = uint256.NewInt(MaxFeeInt) + +func ComputeSwapStep( + sqrtRatioCurrentX96, + sqrtRatioTargetX96 *Uint160, + liquidity *Uint128, + amountRemaining *int256.Int, + feePips constants.FeeAmount, +) (sqrtRatioNextX96 *Uint160, amountIn, amountOut, feeAmount *uint256.Int, err error) { zeroForOne := sqrtRatioCurrentX96.Cmp(sqrtRatioTargetX96) >= 0 - exactIn := amountRemaining.Cmp(constants.Zero) >= 0 + exactIn := amountRemaining.Sign() >= 0 + + var amountRemainingU uint256.Int + if exactIn { + amountRemainingBI := amountRemaining.ToBig() + amountRemainingU.SetFromBig(amountRemainingBI) // TODO: optimize this + } else { + amountRemaining1 := new(int256.Int).Set(amountRemaining) + amountRemainingBI := amountRemaining1.ToBig() + amountRemainingU.SetFromBig(amountRemainingBI) // TODO: optimize this + amountRemainingU.Neg(&amountRemainingU) + } + var maxFeeMinusFeePips uint256.Int + maxFeeMinusFeePips.SetUint64(MaxFeeInt - uint64(feePips)) if exactIn { - amountRemainingLessFee := new(big.Int).Div(new(big.Int).Mul(amountRemaining, new(big.Int).Sub(MaxFee, big.NewInt(int64(feePips)))), MaxFee) + var amountRemainingLessFee, tmp uint256.Int + tmp.Mul(&amountRemainingU, &maxFeeMinusFeePips) + amountRemainingLessFee.Div(&tmp, MaxFeeUint256) if zeroForOne { - amountIn = GetAmount0Delta(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, true) + amountIn, err = GetAmount0DeltaV2(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, true) + if err != nil { + return + } } else { - amountIn = GetAmount1Delta(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, true) + amountIn, err = GetAmount1DeltaV2(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, true) + if err != nil { + return + } } if amountRemainingLessFee.Cmp(amountIn) >= 0 { sqrtRatioNextX96 = sqrtRatioTargetX96 } else { - sqrtRatioNextX96, err = GetNextSqrtPriceFromInput(sqrtRatioCurrentX96, liquidity, amountRemainingLessFee, zeroForOne) + sqrtRatioNextX96, err = GetNextSqrtPriceFromInput(sqrtRatioCurrentX96, liquidity, &amountRemainingLessFee, zeroForOne) if err != nil { return } } } else { if zeroForOne { - amountOut = GetAmount1Delta(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, false) + amountOut, err = GetAmount1DeltaV2(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, false) + if err != nil { + return + } } else { - amountOut = GetAmount0Delta(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, false) + amountOut, err = GetAmount0DeltaV2(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, false) + if err != nil { + return + } } - if new(big.Int).Mul(amountRemaining, constants.NegativeOne).Cmp(amountOut) >= 0 { + if amountRemainingU.Cmp(amountOut) >= 0 { sqrtRatioNextX96 = sqrtRatioTargetX96 } else { - sqrtRatioNextX96, err = GetNextSqrtPriceFromOutput(sqrtRatioCurrentX96, liquidity, new(big.Int).Mul(amountRemaining, constants.NegativeOne), zeroForOne) + sqrtRatioNextX96, err = GetNextSqrtPriceFromOutput(sqrtRatioCurrentX96, liquidity, &amountRemainingU, zeroForOne) if err != nil { return } @@ -47,29 +86,44 @@ func ComputeSwapStep(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, amountR if zeroForOne { if !(max && exactIn) { - amountIn = GetAmount0Delta(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, true) + amountIn, err = GetAmount0DeltaV2(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, true) + if err != nil { + return + } } if !(max && !exactIn) { - amountOut = GetAmount1Delta(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, false) + amountOut, err = GetAmount1DeltaV2(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, false) + if err != nil { + return + } } } else { if !(max && exactIn) { - amountIn = GetAmount1Delta(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, true) + amountIn, err = GetAmount1DeltaV2(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, true) + if err != nil { + return + } } if !(max && !exactIn) { - amountOut = GetAmount0Delta(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, false) + amountOut, err = GetAmount0DeltaV2(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, false) + if err != nil { + return + } } } - if !exactIn && amountOut.Cmp(new(big.Int).Mul(amountRemaining, constants.NegativeOne)) > 0 { - amountOut = new(big.Int).Mul(amountRemaining, constants.NegativeOne) + if !exactIn && amountOut.Cmp(&amountRemainingU) > 0 { + amountOut = &amountRemainingU } if exactIn && sqrtRatioNextX96.Cmp(sqrtRatioTargetX96) != 0 { // we didn't reach the target, so take the remainder of the maximum input as fee - feeAmount = new(big.Int).Sub(amountRemaining, amountIn) + feeAmount = new(uint256.Int).Sub(&amountRemainingU, amountIn) } else { - feeAmount = MulDivRoundingUp(amountIn, big.NewInt(int64(feePips)), new(big.Int).Sub(MaxFee, big.NewInt(int64(feePips)))) + feeAmount, err = MulDivRoundingUp(amountIn, uint256.NewInt(uint64(feePips)), &maxFeeMinusFeePips) + if err != nil { + return + } } return diff --git a/utils/swap_math_test.go b/utils/swap_math_test.go new file mode 100644 index 0000000..7ca3a6c --- /dev/null +++ b/utils/swap_math_test.go @@ -0,0 +1,82 @@ +package utils + +import ( + "fmt" + "math/big" + "testing" + + "github.com/KyberNetwork/int256" + "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestComputeSwapStep(t *testing.T) { + + p1 := EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)) + p2 := EncodeSqrtRatioX96(big.NewInt(101), big.NewInt(100)) + p3 := EncodeSqrtRatioX96(big.NewInt(1000), big.NewInt(100)) + p4 := EncodeSqrtRatioX96(big.NewInt(10000), big.NewInt(100)) + + tests := []struct { + price string + priceTarget string + liquidity string + amount string + fee constants.FeeAmount + + expAmountIn string + expAmountOut string + expFee string + + expNextPrice string + }{ + {p1.String(), p2.String(), "2000000000000000000", "1000000000000000000", 600, + "9975124224178055", "9925619580021728", "5988667735148", "="}, + {p1.String(), p2.String(), "2000000000000000000", "-1000000000000000000", 600, + "9975124224178055", "9925619580021728", "5988667735148", "="}, + + {p1.String(), p3.String(), "2000000000000000000", "1000000000000000000", 600, + "999400000000000000", "666399946655997866", "600000000000000", "<"}, + {p1.String(), p4.String(), "2000000000000000000", "-1000000000000000000", 600, + "2000000000000000000", "1000000000000000000", "1200720432259356", "<"}, + + {"417332158212080721273783715441582", "1452870262520218020823638996", "159344665391607089467575320103", "-1", 1, + "1", "1", "1", "417332158212080721273783715441581"}, + + {"2", "1", "1", "3915081100057732413702495386755767", 1, + "39614081257132168796771975168", "0", "39614120871253040049813", "1"}, + + {"2413", "79887613182836312", "1985041575832132834610021537970", "10", 1872, + "0", "0", "10", "2413"}, + + {"20282409603651670423947251286016", "22310650564016837466341976414617", "1024", "-4", 3000, + "26215", "0", "79", "="}, + + {"20282409603651670423947251286016", "18254168643286503381552526157414", "1024", "-263000", 3000, + "1", "26214", "1", "="}, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + sqrtRatioNextX96, amountIn, amountOut, feeAmount, err := ComputeSwapStep( + uint256.MustFromDecimal(tt.price), + uint256.MustFromDecimal(tt.priceTarget), + uint256.MustFromDecimal(tt.liquidity), + int256.MustFromDec(tt.amount), + tt.fee, + ) + require.Nil(t, err) + if tt.expNextPrice == "=" { + assert.Equal(t, tt.priceTarget, sqrtRatioNextX96.Dec()) + } else if tt.expNextPrice == "<" { + assert.Greater(t, tt.priceTarget, sqrtRatioNextX96.Dec()) + } else { + assert.Equal(t, tt.expNextPrice, sqrtRatioNextX96.Dec()) + } + assert.Equal(t, tt.expAmountIn, amountIn.Dec()) + assert.Equal(t, tt.expAmountOut, amountOut.Dec()) + assert.Equal(t, tt.expFee, feeAmount.Dec()) + }) + } +} diff --git a/utils/tick_math.go b/utils/tick_math.go index 7f1c30f..107ca10 100644 --- a/utils/tick_math.go +++ b/utils/tick_math.go @@ -4,8 +4,8 @@ import ( "errors" "math/big" - "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" - "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/KyberNetwork/int256" + "github.com/holiman/uint256" ) const ( @@ -17,6 +17,10 @@ var ( Q32 = big.NewInt(1 << 32) MinSqrtRatio = big.NewInt(4295128739) // The sqrt ratio corresponding to the minimum tick that could be used on any pool. MaxSqrtRatio, _ = new(big.Int).SetString("1461446703485210103287273052203988822378723970342", 10) // The sqrt ratio corresponding to the maximum tick that could be used on any pool. + + Q32U256 = uint256.NewInt(1 << 32) + MinSqrtRatioU256 = uint256.NewInt(4295128739) // The sqrt ratio corresponding to the minimum tick that could be used on any pool. + MaxSqrtRatioU256 = uint256.MustFromDecimal("1461446703485210103287273052203988822378723970342") // The sqrt ratio corresponding to the maximum tick that could be used on any pool. ) var ( @@ -24,40 +28,48 @@ var ( ErrInvalidSqrtRatio = errors.New("invalid sqrt ratio") ) -func mulShift(val *big.Int, mulBy *big.Int) *big.Int { - - return new(big.Int).Rsh(new(big.Int).Mul(val, mulBy), 128) +func mulShift(val *Uint256, mulBy *Uint256) { + var tmp Uint256 + val.Rsh(tmp.Mul(val, mulBy), 128) } var ( - sqrtConst1, _ = new(big.Int).SetString("fffcb933bd6fad37aa2d162d1a594001", 16) - sqrtConst2, _ = new(big.Int).SetString("100000000000000000000000000000000", 16) - sqrtConst3, _ = new(big.Int).SetString("fff97272373d413259a46990580e213a", 16) - sqrtConst4, _ = new(big.Int).SetString("fff2e50f5f656932ef12357cf3c7fdcc", 16) - sqrtConst5, _ = new(big.Int).SetString("ffe5caca7e10e4e61c3624eaa0941cd0", 16) - sqrtConst6, _ = new(big.Int).SetString("ffcb9843d60f6159c9db58835c926644", 16) - sqrtConst7, _ = new(big.Int).SetString("ff973b41fa98c081472e6896dfb254c0", 16) - sqrtConst8, _ = new(big.Int).SetString("ff2ea16466c96a3843ec78b326b52861", 16) - sqrtConst9, _ = new(big.Int).SetString("fe5dee046a99a2a811c461f1969c3053", 16) - sqrtConst10, _ = new(big.Int).SetString("fcbe86c7900a88aedcffc83b479aa3a4", 16) - sqrtConst11, _ = new(big.Int).SetString("f987a7253ac413176f2b074cf7815e54", 16) - sqrtConst12, _ = new(big.Int).SetString("f3392b0822b70005940c7a398e4b70f3", 16) - sqrtConst13, _ = new(big.Int).SetString("e7159475a2c29b7443b29c7fa6e889d9", 16) - sqrtConst14, _ = new(big.Int).SetString("d097f3bdfd2022b8845ad8f792aa5825", 16) - sqrtConst15, _ = new(big.Int).SetString("a9f746462d870fdf8a65dc1f90e061e5", 16) - sqrtConst16, _ = new(big.Int).SetString("70d869a156d2a1b890bb3df62baf32f7", 16) - sqrtConst17, _ = new(big.Int).SetString("31be135f97d08fd981231505542fcfa6", 16) - sqrtConst18, _ = new(big.Int).SetString("9aa508b5b7a84e1c677de54f3e99bc9", 16) - sqrtConst19, _ = new(big.Int).SetString("5d6af8dedb81196699c329225ee604", 16) - sqrtConst20, _ = new(big.Int).SetString("2216e584f5fa1ea926041bedfe98", 16) - sqrtConst21, _ = new(big.Int).SetString("48a170391f7dc42444e8fa2", 16) + sqrtConst1 = uint256.MustFromHex("0xfffcb933bd6fad37aa2d162d1a594001") + sqrtConst2 = uint256.MustFromHex("0x100000000000000000000000000000000") + sqrtConst3 = uint256.MustFromHex("0xfff97272373d413259a46990580e213a") + sqrtConst4 = uint256.MustFromHex("0xfff2e50f5f656932ef12357cf3c7fdcc") + sqrtConst5 = uint256.MustFromHex("0xffe5caca7e10e4e61c3624eaa0941cd0") + sqrtConst6 = uint256.MustFromHex("0xffcb9843d60f6159c9db58835c926644") + sqrtConst7 = uint256.MustFromHex("0xff973b41fa98c081472e6896dfb254c0") + sqrtConst8 = uint256.MustFromHex("0xff2ea16466c96a3843ec78b326b52861") + sqrtConst9 = uint256.MustFromHex("0xfe5dee046a99a2a811c461f1969c3053") + sqrtConst10 = uint256.MustFromHex("0xfcbe86c7900a88aedcffc83b479aa3a4") + sqrtConst11 = uint256.MustFromHex("0xf987a7253ac413176f2b074cf7815e54") + sqrtConst12 = uint256.MustFromHex("0xf3392b0822b70005940c7a398e4b70f3") + sqrtConst13 = uint256.MustFromHex("0xe7159475a2c29b7443b29c7fa6e889d9") + sqrtConst14 = uint256.MustFromHex("0xd097f3bdfd2022b8845ad8f792aa5825") + sqrtConst15 = uint256.MustFromHex("0xa9f746462d870fdf8a65dc1f90e061e5") + sqrtConst16 = uint256.MustFromHex("0x70d869a156d2a1b890bb3df62baf32f7") + sqrtConst17 = uint256.MustFromHex("0x31be135f97d08fd981231505542fcfa6") + sqrtConst18 = uint256.MustFromHex("0x9aa508b5b7a84e1c677de54f3e99bc9") + sqrtConst19 = uint256.MustFromHex("0x5d6af8dedb81196699c329225ee604") + sqrtConst20 = uint256.MustFromHex("0x2216e584f5fa1ea926041bedfe98") + sqrtConst21 = uint256.MustFromHex("0x48a170391f7dc42444e8fa2") + + MaxUint256 = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") ) +// deprecated +func GetSqrtRatioAtTick(tick int) (*big.Int, error) { + res, err := GetSqrtRatioAtTickV2(tick) + return res.ToBig(), err +} + /** * Returns the sqrt ratio as a Q64.96 for the given tick. The sqrt ratio is computed as sqrt(1.0001)^tick * @param tick the tick for which to compute the sqrt ratio */ -func GetSqrtRatioAtTick(tick int) (*big.Int, error) { +func GetSqrtRatioAtTickV2(tick int) (*Uint160, error) { if tick < MinTick || tick > MaxTick { return nil, ErrInvalidTick } @@ -65,127 +77,146 @@ func GetSqrtRatioAtTick(tick int) (*big.Int, error) { if tick < 0 { absTick = -tick } - var ratio *big.Int + var ratio, tmp Uint256 if absTick&0x1 != 0 { - ratio = sqrtConst1 + ratio.Set(sqrtConst1) } else { - ratio = sqrtConst2 + ratio.Set(sqrtConst2) } if (absTick & 0x2) != 0 { - ratio = mulShift(ratio, sqrtConst3) + mulShift(&ratio, sqrtConst3) } if (absTick & 0x4) != 0 { - ratio = mulShift(ratio, sqrtConst4) + mulShift(&ratio, sqrtConst4) } if (absTick & 0x8) != 0 { - ratio = mulShift(ratio, sqrtConst5) + mulShift(&ratio, sqrtConst5) } if (absTick & 0x10) != 0 { - ratio = mulShift(ratio, sqrtConst6) + mulShift(&ratio, sqrtConst6) } if (absTick & 0x20) != 0 { - ratio = mulShift(ratio, sqrtConst7) + mulShift(&ratio, sqrtConst7) } if (absTick & 0x40) != 0 { - ratio = mulShift(ratio, sqrtConst8) + mulShift(&ratio, sqrtConst8) } if (absTick & 0x80) != 0 { - ratio = mulShift(ratio, sqrtConst9) + mulShift(&ratio, sqrtConst9) } if (absTick & 0x100) != 0 { - ratio = mulShift(ratio, sqrtConst10) + mulShift(&ratio, sqrtConst10) } if (absTick & 0x200) != 0 { - ratio = mulShift(ratio, sqrtConst11) + mulShift(&ratio, sqrtConst11) } if (absTick & 0x400) != 0 { - ratio = mulShift(ratio, sqrtConst12) + mulShift(&ratio, sqrtConst12) } if (absTick & 0x800) != 0 { - ratio = mulShift(ratio, sqrtConst13) + mulShift(&ratio, sqrtConst13) } if (absTick & 0x1000) != 0 { - ratio = mulShift(ratio, sqrtConst14) + mulShift(&ratio, sqrtConst14) } if (absTick & 0x2000) != 0 { - ratio = mulShift(ratio, sqrtConst15) + mulShift(&ratio, sqrtConst15) } if (absTick & 0x4000) != 0 { - ratio = mulShift(ratio, sqrtConst16) + mulShift(&ratio, sqrtConst16) } if (absTick & 0x8000) != 0 { - ratio = mulShift(ratio, sqrtConst17) + mulShift(&ratio, sqrtConst17) } if (absTick & 0x10000) != 0 { - ratio = mulShift(ratio, sqrtConst18) + mulShift(&ratio, sqrtConst18) } if (absTick & 0x20000) != 0 { - ratio = mulShift(ratio, sqrtConst19) + mulShift(&ratio, sqrtConst19) } if (absTick & 0x40000) != 0 { - ratio = mulShift(ratio, sqrtConst20) + mulShift(&ratio, sqrtConst20) } if (absTick & 0x80000) != 0 { - ratio = mulShift(ratio, sqrtConst21) + mulShift(&ratio, sqrtConst21) } if tick > 0 { - ratio = new(big.Int).Div(entities.MaxUint256, ratio) + tmp.Div(MaxUint256, &ratio) + ratio.Set(&tmp) } // back to Q96 - if new(big.Int).Rem(ratio, Q32).Cmp(constants.Zero) > 0 { - return new(big.Int).Add((new(big.Int).Div(ratio, Q32)), constants.One), nil + var rem Uint256 + tmp.DivMod(&ratio, Q32U256, &rem) + if !rem.IsZero() { + tmp.AddUint64(&tmp, 1) + return &tmp, nil } else { - return new(big.Int).Div(ratio, Q32), nil + return &tmp, nil } } var ( - magicSqrt10001, _ = new(big.Int).SetString("255738958999603826347141", 10) - magicTickLow, _ = new(big.Int).SetString("3402992956809132418596140100660247210", 10) - magicTickHigh, _ = new(big.Int).SetString("291339464771989622907027621153398088495", 10) + magicSqrt10001 = int256.MustFromDec("255738958999603826347141") + magicTickLow = int256.MustFromDec("3402992956809132418596140100660247210") + magicTickHigh = int256.MustFromDec("291339464771989622907027621153398088495") ) +// deprecated +func GetTickAtSqrtRatio(sqrtRatioX96 *big.Int) (int, error) { + return GetTickAtSqrtRatioV2(uint256.MustFromBig(sqrtRatioX96)) +} + /** * Returns the tick corresponding to a given sqrt ratio, s.t. #getSqrtRatioAtTick(tick) <= sqrtRatioX96 * and #getSqrtRatioAtTick(tick + 1) > sqrtRatioX96 * @param sqrtRatioX96 the sqrt ratio as a Q64.96 for which to compute the tick */ -func GetTickAtSqrtRatio(sqrtRatioX96 *big.Int) (int, error) { - if sqrtRatioX96.Cmp(MinSqrtRatio) < 0 || sqrtRatioX96.Cmp(MaxSqrtRatio) >= 0 { +func GetTickAtSqrtRatioV2(sqrtRatioX96 *Uint160) (int, error) { + if sqrtRatioX96.Cmp(MinSqrtRatioU256) < 0 || sqrtRatioX96.Cmp(MaxSqrtRatioU256) >= 0 { return 0, ErrInvalidSqrtRatio } - sqrtRatioX128 := new(big.Int).Lsh(sqrtRatioX96, 32) - msb, err := MostSignificantBit(sqrtRatioX128) + var sqrtRatioX128 Uint256 + sqrtRatioX128.Lsh(sqrtRatioX96, 32) + msb, err := MostSignificantBit(&sqrtRatioX128) if err != nil { return 0, err } - var r *big.Int - if big.NewInt(msb).Cmp(big.NewInt(128)) >= 0 { - r = new(big.Int).Rsh(sqrtRatioX128, uint(msb-127)) + var r Uint256 + if msb >= 128 { + r.Rsh(&sqrtRatioX128, msb-127) } else { - r = new(big.Int).Lsh(sqrtRatioX128, uint(127-msb)) + r.Lsh(&sqrtRatioX128, 127-msb) } - log2 := new(big.Int).Lsh(new(big.Int).Sub(big.NewInt(msb), big.NewInt(128)), 64) + log2 := int256.NewInt(int64(msb - 128)) + log2.Lsh(log2, 64) + var tmp, f Uint256 for i := 0; i < 14; i++ { - r = new(big.Int).Rsh(new(big.Int).Mul(r, r), 127) - f := new(big.Int).Rsh(r, 128) - log2 = new(big.Int).Or(log2, new(big.Int).Lsh(f, uint(63-i))) - r = new(big.Int).Rsh(r, uint(f.Int64())) + tmp.Mul(&r, &r) + r.Rsh(&tmp, 127) + f.Rsh(&r, 128) + tmp.Lsh(&f, uint(63-i)) + + // this is for Or, so we can cast the underlying words directly without copying + tmpsigned := (*int256.Int)(&tmp) + + log2.Or(log2, tmpsigned) + r.Rsh(&r, uint(f.Uint64())) } - logSqrt10001 := new(big.Int).Mul(log2, magicSqrt10001) + var logSqrt10001, tmp1, tmp2 Int256 + logSqrt10001.Mul(log2, magicSqrt10001) - tickLow := new(big.Int).Rsh(new(big.Int).Sub(logSqrt10001, magicTickLow), 128).Int64() - tickHigh := new(big.Int).Rsh(new(big.Int).Add(logSqrt10001, magicTickHigh), 128).Int64() + tickLow := tmp2.Rsh(tmp1.Sub(&logSqrt10001, magicTickLow), 128).Uint64() + tickHigh := tmp2.Rsh(tmp1.Add(&logSqrt10001, magicTickHigh), 128).Uint64() if tickLow == tickHigh { return int(tickLow), nil } - sqrtRatio, err := GetSqrtRatioAtTick(int(tickHigh)) + sqrtRatio, err := GetSqrtRatioAtTickV2(int(tickHigh)) if err != nil { return 0, err } diff --git a/utils/tick_math_test.go b/utils/tick_math_test.go index e280a16..91f6bef 100644 --- a/utils/tick_math_test.go +++ b/utils/tick_math_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/KyberNetwork/uniswapv3-sdk-uint256/constants" + "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) @@ -18,17 +19,35 @@ func TestGetSqrtRatioAtTick(t *testing.T) { rmax, _ := GetSqrtRatioAtTick(MinTick) assert.Equal(t, rmax, MinSqrtRatio, "returns the correct value for min tick") + r, _ := GetSqrtRatioAtTickV2(MinTick + 1) + assert.Equal(t, uint256.NewInt(4295343490), r, "returns the correct value for min tick + 1") + r0, _ := GetSqrtRatioAtTick(0) assert.Equal(t, r0, new(big.Int).Lsh(constants.One, 96), "returns the correct value for tick 0") rmin, _ := GetSqrtRatioAtTick(MaxTick) assert.Equal(t, rmin, MaxSqrtRatio, "returns the correct value for max tick") + + r, _ = GetSqrtRatioAtTickV2(MaxTick - 1) + assert.Equal(t, uint256.MustFromDecimal("1461373636630004318706518188784493106690254656249"), r, "returns the correct value for max tick - 1") + + r, _ = GetSqrtRatioAtTickV2(MaxTick) + assert.Equal(t, uint256.MustFromDecimal("1461446703485210103287273052203988822378723970342"), r, "returns the correct value for max tick") } func TestGetTickAtSqrtRatio(t *testing.T) { tmin, _ := GetTickAtSqrtRatio(MinSqrtRatio) assert.Equal(t, tmin, MinTick, "returns the correct value for sqrt ratio at min tick") + _, err := GetTickAtSqrtRatioV2(new(uint256.Int).SubUint64(MinSqrtRatioU256, 1)) + assert.ErrorIs(t, ErrInvalidSqrtRatio, err) + + _, err = GetTickAtSqrtRatioV2(MaxSqrtRatioU256) + assert.ErrorIs(t, ErrInvalidSqrtRatio, err) + tmax, _ := GetTickAtSqrtRatio(new(big.Int).Sub(MaxSqrtRatio, constants.One)) assert.Equal(t, tmax, MaxTick-1, "returns the correct value for sqrt ratio at max tick") + + tt, _ := GetTickAtSqrtRatio(big.NewInt(4295343490)) + assert.Equal(t, MinTick+1, tt) }