diff --git a/entities/pool.go b/entities/pool.go index 99753c3..9aff6a8 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -218,14 +218,14 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int } if zeroForOne { - if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) <= 0 { + if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 { return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 { return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh } } else { - if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) >= 0 { + if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 { return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh } if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 { @@ -259,7 +259,10 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int // because each iteration of the while loop rounds, we can't optimize this code (relative to the smart contract) // by simply traversing to the next available tick, we instead need to exactly replicate // tickBitmap.nextInitializedTickWithinOneWord - step.tickNext, step.initialized = p.TickDataProvider.NextInitializedTickWithinOneWord(state.tick, zeroForOne, p.tickSpacing()) + step.tickNext, step.initialized, err = p.TickDataProvider.NextInitializedTickIndex(state.tick, zeroForOne) + if err != nil { + return nil, nil, nil, 0, err + } if step.tickNext < utils.MinTick { step.tickNext = utils.MinTick @@ -303,7 +306,12 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int if state.sqrtPriceX96.Cmp(step.sqrtPriceNextX96) == 0 { // if the tick is initialized, run the tick transition if step.initialized { - liquidityNet := p.TickDataProvider.GetTick(step.tickNext).LiquidityNet + tick, err := p.TickDataProvider.GetTick(step.tickNext) + if err != nil { + return nil, nil, nil, 0, err + } + + liquidityNet := tick.LiquidityNet // if we're moving leftward, we interpret liquidityNet as the opposite sign // safe because liquidityNet cannot be type(int128).min if zeroForOne { diff --git a/entities/tickdataprovider.go b/entities/tickdataprovider.go index e50d5be..6086e61 100644 --- a/entities/tickdataprovider.go +++ b/entities/tickdataprovider.go @@ -14,7 +14,7 @@ type TickDataProvider interface { * Return information corresponding to a specific tick * @param tick the tick to load */ - GetTick(tick int) Tick + GetTick(tick int) (Tick, error) /** * Return the next tick that is initialized within a single word @@ -22,5 +22,8 @@ type TickDataProvider interface { * @param lte Whether the next tick should be lte the current tick * @param tickSpacing The tick spacing of the pool */ - NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool) + NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool, error) + + // NextInitializedTickIndex return the next tick that is initialized + NextInitializedTickIndex(tick int, lte bool) (int, bool, error) } diff --git a/entities/ticklist.go b/entities/ticklist.go index d287e55..82953e1 100644 --- a/entities/ticklist.go +++ b/entities/ticklist.go @@ -6,11 +6,24 @@ import ( "math/big" ) +const ( + ZeroValueTickIndex = 0 + ZeroValueTickInitialized = false +) + var ( ErrZeroTickSpacing = errors.New("tick spacing must be greater than 0") ErrInvalidTickSpacing = errors.New("invalid tick spacing") ErrZeroNet = errors.New("tick net delta must be zero") ErrSorted = errors.New("ticks must be sorted") + ErrEmptyTickList = errors.New("empty tick list") + ErrBelowSmallest = errors.New("below smallest") + ErrAtOrAboveLargest = errors.New("at or above largest") + ErrInvalidTickIndex = errors.New("invalid tick index") +) + +var ( + EmptyTick = Tick{} ) func ValidateList(ticks []Tick, tickSpacing int) error { @@ -41,74 +54,148 @@ func ValidateList(ticks []Tick, tickSpacing int) error { return nil } -func IsBelowSmallest(ticks []Tick, tick int) bool { +func IsBelowSmallest(ticks []Tick, tick int) (bool, error) { if len(ticks) == 0 { - panic("empty tick list") + return true, ErrEmptyTickList } - return tick < ticks[0].Index + + return tick < ticks[0].Index, nil } -func IsAtOrAboveLargest(ticks []Tick, tick int) bool { +func IsAtOrAboveLargest(ticks []Tick, tick int) (bool, error) { if len(ticks) == 0 { - panic("empty tick list") + return true, ErrEmptyTickList } - return tick >= ticks[len(ticks)-1].Index + + return tick >= ticks[len(ticks)-1].Index, nil } -func GetTick(ticks []Tick, index int) Tick { - tick := ticks[binarySearch(ticks, index)] - if tick.Index != index { - panic("index is not contained in ticks") +func GetTick(ticks []Tick, index int) (Tick, error) { + tickIndex, err := binarySearch(ticks, index) + if err != nil { + return EmptyTick, err + } + + if tickIndex < 0 { + return EmptyTick, ErrInvalidTickIndex } - return tick + + tick := ticks[tickIndex] + + return tick, nil } -func NextInitializedTick(ticks []Tick, tick int, lte bool) Tick { +func NextInitializedTick(ticks []Tick, tick int, lte bool) (Tick, error) { if lte { - if IsBelowSmallest(ticks, tick) { - panic("below smallest") + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return EmptyTick, err } - if IsAtOrAboveLargest(ticks, tick) { - return ticks[len(ticks)-1] + + if isBelowSmallest { + return EmptyTick, ErrBelowSmallest + } + + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return EmptyTick, err } - index := binarySearch(ticks, tick) - return ticks[index] + + if isAtOrAboveLargest { + return ticks[len(ticks)-1], nil + } + + index, err := binarySearch(ticks, tick) + if err != nil { + return EmptyTick, err + } + + return ticks[index], nil } else { - if IsAtOrAboveLargest(ticks, tick) { - panic("at or above largest") + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return EmptyTick, err + } + + if isAtOrAboveLargest { + return EmptyTick, ErrAtOrAboveLargest + } + + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + + if err != nil { + return EmptyTick, err + } + + if isBelowSmallest { + return ticks[0], nil } - if IsBelowSmallest(ticks, tick) { - return ticks[0] + + index, err := binarySearch(ticks, tick) + if err != nil { + return EmptyTick, err } - index := binarySearch(ticks, tick) - return ticks[index+1] + + return ticks[index+1], nil } } -func NextInitializedTickWithinOneWord(ticks []Tick, tick int, lte bool, tickSpacing int) (int, bool) { +func NextInitializedTickWithinOneWord(ticks []Tick, tick int, lte bool, tickSpacing int) (int, bool, error) { compressed := math.Floor(float64(tick) / float64(tickSpacing)) // matches rounding in the code if lte { wordPos := int(compressed) >> 8 minimum := (wordPos << 8) * tickSpacing - if IsBelowSmallest(ticks, tick) { - return minimum, false + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err } - index := NextInitializedTick(ticks, tick, lte).Index - nextInitializedTick := math.Max(float64(minimum), float64(index)) - return int(nextInitializedTick), int(nextInitializedTick) == index + + if isBelowSmallest { + return minimum, ZeroValueTickInitialized, ErrBelowSmallest + } + + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + index := nextInitializedTick.Index + nextInitializedTickIndex := math.Max(float64(minimum), float64(index)) + return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil } else { wordPos := int(compressed+1) >> 8 maximum := ((wordPos+1)<<8)*tickSpacing - 1 - if IsAtOrAboveLargest(ticks, tick) { - return maximum, false + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err } - index := NextInitializedTick(ticks, tick, lte).Index - nextInitializedTick := math.Min(float64(maximum), float64(index)) - return int(nextInitializedTick), int(nextInitializedTick) == index + + if isAtOrAboveLargest { + return maximum, ZeroValueTickInitialized, ErrAtOrAboveLargest + } + + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + index := nextInitializedTick.Index + nextInitializedTickIndex := math.Min(float64(maximum), float64(index)) + return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil } } +func NextInitializedTickIndex(ticks []Tick, tick int, lte bool) (int, bool, error) { + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + // The found tick is surely initialized + return nextInitializedTick.Index, true, nil +} + // utils func isTicksSorted(ticks []Tick) bool { @@ -126,9 +213,14 @@ func isTicksSorted(ticks []Tick) bool { * @param tick tick to find the largest tick that is less than or equal to tick * @private */ -func binarySearch(ticks []Tick, tick int) int { - if IsBelowSmallest(ticks, tick) { - panic("tick is below smallest tick") +func binarySearch(ticks []Tick, tick int) (int, error) { + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, err + } + + if isBelowSmallest { + return ZeroValueTickIndex, ErrBelowSmallest } // binary search @@ -137,7 +229,7 @@ func binarySearch(ticks []Tick, tick int) int { for start <= end { mid := (start + end) / 2 if ticks[mid].Index == tick { - return mid + return mid, nil } else if ticks[mid].Index < tick { start = mid + 1 } else { @@ -148,8 +240,8 @@ func binarySearch(ticks []Tick, tick int) int { // if we get here, we didn't find a tick that is less than or equal to tick // so we return the index of the tick that is closest to tick if ticks[start].Index < tick { - return start + return start, nil } else { - return start - 1 + return start - 1, nil } } diff --git a/entities/ticklist_test.go b/entities/ticklist_test.go index 02e1ce8..9afc6bd 100644 --- a/entities/ticklist_test.go +++ b/entities/ticklist_test.go @@ -34,14 +34,21 @@ func TestValidateList(t *testing.T) { func TestIsBelowSmallest(t *testing.T) { result := []Tick{lowTick, midTick, highTick} - assert.True(t, IsBelowSmallest(result, utils.MinTick)) - assert.False(t, IsBelowSmallest(result, utils.MinTick+1)) + isBelowSmallest1, _ := IsBelowSmallest(result, utils.MinTick) + assert.True(t, isBelowSmallest1) + + isBelowSmallest2, _ := IsBelowSmallest(result, utils.MinTick+1) + assert.False(t, isBelowSmallest2) } func TestIsAtOrAboveSmallest(t *testing.T) { result := []Tick{lowTick, midTick, highTick} - assert.False(t, IsAtOrAboveLargest(result, utils.MaxTick-2)) - assert.True(t, IsAtOrAboveLargest(result, utils.MaxTick-1)) + + isAtOrAboveLargest1, _ := IsAtOrAboveLargest(result, utils.MaxTick-2) + assert.False(t, isAtOrAboveLargest1) + + isAtOrAboveLargest2, _ := IsAtOrAboveLargest(result, utils.MaxTick-1) + assert.True(t, isAtOrAboveLargest2) } func TestNextInitializedTick(t *testing.T) { @@ -73,12 +80,18 @@ func TestNextInitializedTick(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, NextInitializedTick(tt.args.ticks, tt.args.tick, tt.args.lte)) + nextInitializedTick, _ := NextInitializedTick(tt.args.ticks, tt.args.tick, tt.args.lte) + assert.Equal(t, tt.want, nextInitializedTick) }) } - assert.Panics(t, func() { NextInitializedTick(ticks, utils.MinTick, true) }, "blow smallest") - assert.Panics(t, func() { NextInitializedTick(ticks, utils.MaxTick-1, false) }, "at or above largest") + nextInitializedTick1, err1 := NextInitializedTick(ticks, utils.MinTick, true) + assert.Zero(t, nextInitializedTick1, "below smallest") + assert.ErrorIs(t, err1, ErrBelowSmallest) + + nextInitializedTick2, err2 := NextInitializedTick(ticks, utils.MaxTick-1, false) + assert.Zero(t, nextInitializedTick2, "at or above largest") + assert.ErrorIs(t, err2, ErrAtOrAboveLargest) } func TestNextInitializedTickWithinOneWord(t *testing.T) { @@ -122,7 +135,7 @@ func TestNextInitializedTickWithinOneWord(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got0, got1 := NextInitializedTickWithinOneWord(tt.args.ticks, tt.args.tick, tt.args.lte, tt.args.tickSpacing) + got0, got1, _ := NextInitializedTickWithinOneWord(tt.args.ticks, tt.args.tick, tt.args.lte, tt.args.tickSpacing) assert.Equal(t, tt.want0, got0) assert.Equal(t, tt.want1, got1) }) diff --git a/entities/ticklistdataprovider.go b/entities/ticklistdataprovider.go index 32e6084..e6e6429 100644 --- a/entities/ticklistdataprovider.go +++ b/entities/ticklistdataprovider.go @@ -12,10 +12,14 @@ func NewTickListDataProvider(ticks []Tick, tickSpacing int) (*TickListDataProvid return &TickListDataProvider{ticks: ticks}, nil } -func (p *TickListDataProvider) GetTick(tick int) Tick { +func (p *TickListDataProvider) GetTick(tick int) (Tick, error) { return GetTick(p.ticks, tick) } -func (p *TickListDataProvider) NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool) { +func (p *TickListDataProvider) NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool, error) { return NextInitializedTickWithinOneWord(p.ticks, tick, lte, tickSpacing) } + +func (p *TickListDataProvider) NextInitializedTickIndex(tick int, lte bool) (int, bool, error) { + return NextInitializedTickIndex(p.ticks, tick, lte) +} diff --git a/entities/trade_test.go b/entities/trade_test.go index b778154..78a1e9f 100644 --- a/entities/trade_test.go +++ b/entities/trade_test.go @@ -453,7 +453,7 @@ func TestBestTradeExactIn(t *testing.T) { assert.Equal(t, len(result[0].Swaps[0].Route.Pools), 1) assert.Equal(t, result[0].Swaps[0].Route.TokenPath, []*entities.Token{token0, token2}) assert.True(t, result[0].InputAmount().EqualTo(entities.FromRawAmount(token0, big.NewInt(10000)).Fraction)) - assert.True(t, result[0].OutputAmount().EqualTo(entities.FromRawAmount(token2, big.NewInt(9971)).Fraction)) + assert.True(t, result[0].OutputAmount().EqualTo(entities.FromRawAmount(token2, big.NewInt(9972)).Fraction)) assert.Equal(t, len(result[1].Swaps[0].Route.Pools), 2) assert.Equal(t, result[1].Swaps[0].Route.TokenPath, []*entities.Token{token0, token1, token2}) assert.True(t, result[1].InputAmount().EqualTo(entities.FromRawAmount(token0, big.NewInt(10000)).Fraction)) diff --git a/go.mod b/go.mod index e882bc0..dc5bacf 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,11 @@ module github.com/KyberNetwork/uniswapv3-sdk go 1.18 +replace github.com/daoleno/uniswapv3-sdk v0.4.0 => github.com/KyberNetwork/uniswapv3-sdk v0.4.0 + require ( github.com/daoleno/uniswap-sdk-core v0.1.5 + github.com/daoleno/uniswapv3-sdk v0.4.0 github.com/ethereum/go-ethereum v1.10.20 github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.0 diff --git a/go.sum b/go.sum index d95108d..fba3742 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/KyberNetwork/uniswapv3-sdk v0.4.0 h1:hbTeJBFgFqYqYTduGuEnb4JIvCtcmuvBTFuRARJIa1Y= +github.com/KyberNetwork/uniswapv3-sdk v0.4.0/go.mod h1:K+cqy6zkitxxfShghmuoVwjGJWO16FTXAV+dvddXtgw= github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU=