Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some arithmetic primops #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions Clear/PrimOps.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,28 @@ abbrev fromBool := Bool.toUInt256

def evmAddMod (a b c : UInt256) : UInt256 :=
if c = 0 then 0 else
Fin.mod (a + b) c
-- "All intermediate calculations of this operation are **not** subject to the 2^256 modulo."
Fin.mod (a.val + b.val) c

def evmMulMod (a b c : UInt256) : UInt256 :=
if c = 0 then 0 else
Fin.mod (a * b) c
-- "All intermediate calculations of this operation are **not** subject to the 2^256 modulo."
Fin.mod (a.val * b.val) c

def evmExp (a b : UInt256) : UInt256 :=
a ^ b.val

def evmMod (x y : UInt256) : UInt256 :=
if y == 0 then 0 else x % y

set_option linter.unusedVariables false in
def primCall (s : State) : PrimOp → List Literal → State × List Literal
| .Add, [a,b] => (s, [a + b])
| .Sub, [a,b] => (s, [a - b])
| .Mul, [a,b] => (s, [a * b])
| .Div, [a,b] => (s, [a / b])
| .Sdiv, [a,b] => (s, [UInt256.sdiv a b])
| .Mod, [a,b] => (s, [Fin.mod a b])
| .Mod, [a,b] => (s, [evmMod a b])
| .Smod, [a,b] => (s, [UInt256.smod a b])
| .Addmod, [a,b,c] => (s, [evmAddMod a b c])
| .Mulmod, [a,b,c] => (s, [evmMulMod a b c])
Expand Down Expand Up @@ -109,7 +114,7 @@ lemma EVMSub' : primCall s .Sub [a,b] = (s, [a -
lemma EVMMul' : primCall s .Mul [a,b] = (s, [a * b]) := rfl
lemma EVMDiv' : primCall s .Div [a,b] = (s, [a / b]) := rfl
lemma EVMSdiv' : primCall s .Sdiv [a,b] = (s, [UInt256.sdiv a b]) := rfl
lemma EVMMod' : primCall s .Mod [a,b] = (s, [Fin.mod a b]) := rfl
lemma EVMMod' : primCall s .Mod [a,b] = (s, [evmMod a b]) := rfl
lemma EVMSmod' : primCall s .Smod [a,b] = (s, [UInt256.smod a b]) := rfl
lemma EVMAddmod' : primCall s .Addmod [a,b,c] = (s, [evmAddMod a b c]) := rfl
lemma EVMMulmod' : primCall s .Mulmod [a,b,c] = (s, [evmMulMod a b c]) := rfl
Expand Down
17 changes: 9 additions & 8 deletions Clear/UInt256.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def eq0 (a : UInt256) : Bool := a = 0
def lnot (a : UInt256) : UInt256 := (UInt256.size - 1) - a

def byteAt (a b : UInt256) : UInt256 :=
b >>> (31 - a) * 8 <<< 248
b >>> (.ofNat ((31 - a.val) * 8)) &&& 0xFF

def sgn (a : UInt256) : UInt256 :=
if a ≥ 2 ^ 255 then -1
Expand All @@ -76,14 +76,15 @@ def sdiv (a b : UInt256) : UInt256 :=
else a / b

def smod (a b : UInt256) : UInt256 :=
if a ≥ 2 ^ 255 then
if b ≥ 2 ^ 255 then
Fin.mod (abs a) (abs b)
else (-1) * Fin.mod (abs a) b
if b == 0 then 0
else
if b ≥ 2 ^ 255 then
(-1) * Fin.mod a (abs b)
else Fin.mod a b
let sgnA := if 2 ^ 255 <= a then -1 else 1
let sgnB := if 2 ^ 255 <= b then -1 else 1
let mask : UInt256 := .ofNat (2 ^ 256 - 1 : ℕ)
let absA := if sgnA == 1 then a else - (.xor a mask + 1)
let absB := if sgnB == 1 then b else - (.xor b mask + 1)
sgnA * (absA % absB)


def slt (a b : UInt256) : Bool :=
if a ≥ 2 ^ 255 then
Expand Down