Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSW-1045 feat: update uint256 overflow calcualtion logic #215

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 105 additions & 79 deletions _deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno
Original file line number Diff line number Diff line change
@@ -1,102 +1,128 @@
// REF: https://github.com/Uniswap/solidity-lib/blob/master/contracts/libraries/FullMath.sol
// REF: https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol
package uint256

const (
MAX_UINT256 = "115792089237316195423570985008687907853269984665640564039457584007913129639935"
)

func fullMul(
x *Uint,
y *Uint,
) (*Uint, *Uint) { // l, h
mm := new(Uint).MulMod(x, y, MustFromDecimal(MAX_UINT256))

l := new(Uint).Mul(x, y)
h := new(Uint).Sub(mm, l)

if mm.Lt(l) {
h = new(Uint).Sub(h, One())
}

return l, h
}

func fullDiv(
l *Uint,
h *Uint,
d *Uint,
func MulDiv(
a, b, denominator *Uint,
) *Uint {
// uint256 pow2 = d & -d;
// d
_negD := new(Uint).Neg(d)
pow2 := new(Uint).And(d, _negD)
d = new(Uint).Div(d, pow2)
l = new(Uint).Div(l, pow2)

_negPow2 := new(Uint).Neg(pow2)

value1 := new(Uint).Div(_negPow2, pow2) // (-pow2) / pow2
value2 := new(Uint).Add(value1, One()) // (-pow2) / pow2 + 1)
value3 := new(Uint).Mul(h, value2) // h * ((-pow2) / pow2 + 1);
l = new(Uint).Add(l, value3)

r := One()
for i := 0; i < 8; i++ {
value1 := new(Uint).Mul(d, r) // d * r
value2 := new(Uint).Sub(NewUint(2), value1) // 2 - ( d * r )
r = new(Uint).Mul(r, value2) // r *= 2 - d * r;
prod0 := Zero()
prod1 := Zero()

{
mm := new(Uint).MulMod(a, b, new(Uint).Not(Zero()))
prod0 = new(Uint).Mul(a, b)

ltBool := mm.Lt(prod0)
ltUint := Zero()
if ltBool {
ltUint = One()
}
prod1 = new(Uint).Sub(new(Uint).Sub(mm, prod0), ltUint)
}
res := new(Uint).Mul(l, r)
return res
}

func MulDiv(
x *Uint,
y *Uint,
d *Uint,
) *Uint {
l, h := fullMul(x, y)
mm := new(Uint).MulMod(x, y, d)
// Handle non-overflow cases, 256 by 256 division
if prod1.IsZero() {
if !(denominator.Gt(Zero())) { // require(denominator > 0);
panic("denominator > 0")
}

if mm.Gt(l) {
h = new(Uint).Sub(h, One())
result := new(Uint).Div(prod0, denominator)
return result
}
l = new(Uint).Sub(l, mm)

if h.IsZero() {
return new(Uint).Div(l, d)
// Make sure the result is less than 2**256.
// Also prevents denominator == 0
if !(denominator.Gt(prod1)) { // require(denominator > prod1)
panic("denominator > prod1")
}

if !(h.Lt(d)) {
panic("FULLDIV_OVERFLOW")
}
///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////

// Make division exact by subtracting the remainder from [prod1 prod0]
// Compute remainder using mulmod
remainder := Zero()
remainder = new(Uint).MulMod(a, b, denominator)

return fullDiv(l, h, d)
// Subtract 256 bit number from 512 bit number
gtBool := remainder.Gt(prod0)
gtUint := Zero()
if gtBool {
gtUint = One()
}
prod1 = new(Uint).Sub(prod1, gtUint)
prod0 = new(Uint).Sub(prod0, remainder)

// Factor powers of two out of denominator
// Compute largest power of two divisor of denominator.
// Always >= 1.
twos := Zero()
twos = new(Uint).And(new(Uint).Neg(denominator), denominator)

// Divide denominator by power of two
denominator = new(Uint).Div(denominator, twos)

// Divide [prod1 prod0] by the factors of two
prod0 = new(Uint).Div(prod0, twos)

// Shift in bits from prod1 into prod0. For this we need
// to flip `twos` such that it is 2**256 / twos.
// If twos is zero, then it becomes one
twos = new(Uint).Add(
new(Uint).Div(
new(Uint).Sub(Zero(), twos),
twos,
),
One(),
)
prod0 = new(Uint).Or(prod0, new(Uint).Mul(prod1, twos))

// Invert denominator mod 2**256
// Now that denominator is an odd number, it has an inverse
// modulo 2**256 such that denominator * inv = 1 mod 2**256.
// Compute the inverse by starting with a seed that is correct
// correct for four bits. That is, denominator * inv = 1 mod 2**4
inv := Zero()
inv = new(Uint).Mul(NewUint(3), denominator)
inv = new(Uint).Xor(inv, NewUint(2))

// Now use Newton-Raphson iteration to improve the precision.
// Thanks to Hensel's lifting lemma, this also works in modular
// arithmetic, doubling the correct bits in each step.

inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**8
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**16
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**32
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**64
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**128
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**256

// Because the division is now exact we can divide by multiplying
// with the modular inverse of denominator. This will give us the
// correct result modulo 2**256. Since the precoditions guarantee
// that the outcome is less than 2**256, this is the final result.
// We don't need to compute the high bits of the result and prod1
// is no longer required.
result := new(Uint).Mul(prod0, inv)
return result
}

func DivRoundingUp(
x *Uint,
y *Uint,
func MulDivRoundingUp(
a, b, denominator *Uint,
) *Uint {
div := new(Uint).Div(x, y)
result := MulDiv(a, b, denominator)

mod := new(Uint).Mod(x, y)
return new(Uint).Add(div, gt(mod, Zero()))
}
if new(Uint).MulMod(a, b, denominator).Gt(Zero()) {
if !(result.Lt(MustFromDecimal(MAX_UINT256))) { // require(result < MAX_UINT256)
panic("result < MAX_UINT256")
}

// HELPERs
func lt(x, y *Uint) *Uint {
if x.Lt(y) {
return One()
} else {
return Zero()
result = new(Uint).Add(result, One())
}
}

func gt(x, y *Uint) *Uint {
if x.Gt(y) {
return One()
} else {
return Zero()
}
return result
}
25 changes: 25 additions & 0 deletions _deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package uint256

import "testing"

func TestMulDiv(t *testing.T) {
a := MustFromDecimal("3961170441225674086664416884948992")
b := MustFromDecimal("1461300573427867316490840528175048480732148624513")
c := MustFromDecimal("1461300573427867316570072651998408279850435624081")

z := MulDiv(a, b, c)
if z.ToString() != "3961170441225674086449641121090634" {
t.Errorf("expected 3961170441225674086449641121090634, got %s", z.ToString())
}
}

func TestMulDivRoundingUp(t *testing.T) {
a := MustFromDecimal("3961170441225674086664416884948992")
b := MustFromDecimal("1461300573427867316490840528175048480732148624513")
c := MustFromDecimal("1461300573427867316570072651998408279850435624081")

z := MulDivRoundingUp(a, b, c)
if z.ToString() != "3961170441225674086449641121090635" {
t.Errorf("expected 3961170441225674086449641121090635, got %s", z.ToString())
}
}
Loading