Skip to content

Commit

Permalink
Lean: fix some extern lean functions
Browse files Browse the repository at this point in the history
1. Fix vectorUpdate being conflated with bitvectorUpdate
2. Use boolean operators instead of Prop connectives, the
   lean coertion mechanism seemed confused in big monadic
   expressions.
  • Loading branch information
ineol authored and Alasdair committed Feb 16, 2025
1 parent 75668d7 commit 8f90ea6
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 79 deletions.
8 changes: 4 additions & 4 deletions lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ therefore be included in just about every Sail specification.

*/

val eq_unit = pure { lean : "Eq", _ : "eq_unit" } : (unit, unit) -> bool(true)
val eq_unit = pure { lean : "BEq.beq", _ : "eq_unit" } : (unit, unit) -> bool(true)
function eq_unit(_, _) = true

val eq_bit = pure { lem : "eq", lean : "Eq", _ : "eq_bit" } : (bit, bit) -> bool
Expand All @@ -70,11 +70,11 @@ val and_bool_no_flow = pure {coq: "andb", lean: "Bool.and", _: "and_bool"} : (bo

val or_bool = pure {coq: "orb", lean: "Bool.or", _: "or_bool"} : forall ('p : Bool) ('q : Bool). (bool('p), bool('q)) -> bool('p | 'q)

val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", lean: "Eq", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)
val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", lean: "BEq.beq", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)

val eq_bool = pure {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", coq: "Bool.eqb", lean: "Eq", _: "eq_bool"} : (bool, bool) -> bool
val eq_bool = pure {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", coq: "Bool.eqb", lean: "BEq.beq", _: "eq_bool"} : (bool, bool) -> bool

val neq_int = pure {lem: "neq", lean: "Ne"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
val neq_int = pure {lem: "neq", lean: "bne"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
function neq_int (x, y) = not_bool(eq_int(x, y))

val neq_bool : (bool, bool) -> bool
Expand Down
4 changes: 2 additions & 2 deletions lib/generic_equality.sail
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ $define _GENERIC_EQUALITY

$include <flow.sail>

val eq_anything = pure {ocaml: "(fun (x, y) -> x = y)", lem: "eq", coq: "generic_eq", lean: "Eq", _: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool
val eq_anything = pure {ocaml: "(fun (x, y) -> x = y)", lem: "eq", coq: "generic_eq", lean: "BEq.beq", _: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool

overload operator == = {eq_anything}

val neq_anything = pure {
lem: "neq",
lean: "Ne",
lean: "bne",
coq: "generic_neq"} : forall ('a : Type). ('a, 'a) -> bool

function neq_anything(x, y) = not_bool(eq_anything(y, x))
Expand Down
2 changes: 1 addition & 1 deletion lib/string.sail
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ $include <arith.sail>
val eq_string = pure {
lem: "eq",
coq: "generic_eq",
lean: "Eq",
lean: "BEq.beq",
_: "eq_string"} : (string, string) -> bool

overload operator == = {eq_string}
Expand Down
11 changes: 7 additions & 4 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ val eq_bits = pure {
interpreter: "eq_list",
lem: "eq_vec",
coq: "eq_vec",
lean: "Eq",
lean: "BEq.beq",
_: "eq_bits"
} : forall 'n. (bits('n), bits('n)) -> bool

Expand All @@ -70,7 +70,7 @@ val neq_bits = pure {
lem: "neq_vec",
coq: "neq_vec",
c: "neq_bits",
lean: "Ne"
lean: "bne"
} : forall 'n. (bits('n), bits('n)) -> bool

function neq_bits(x, y) = not_bool(eq_bits(x, y))
Expand All @@ -91,7 +91,10 @@ val vector_length = pure {
_: "length"
} : forall 'n ('a : Type). vector('n, 'a) -> int('n)

val vector_init = pure "vector_init" : forall 'n ('a : Type), 'n >= 0. (implicit('n), 'a) -> vector('n, 'a)
val vector_init = pure {
lean: "Vector.mkVector",
_: "vector_init"
} : forall 'n ('a : Type), 'n >= 0. (implicit('n), 'a) -> vector('n, 'a)

overload length = {bitvector_length, vector_length}

Expand Down Expand Up @@ -205,7 +208,7 @@ val bitvector_update = pure {
interpreter: "update",
lem: "update_vec_dec",
coq: "update_vec_dec",
lean: "vectorUpdate",
lean: "bitvectorUpdate",
_: "vector_update"
} : forall 'n 'm, 0 <= 'm < 'n. (bits('n), int('m), bit) -> bits('n)
$else
Expand Down
2 changes: 2 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def reg_deref (reg_ref : @RegisterRef Register RegisterType α) : PreSailM Regis

def vectorAccess [Inhabited α] (v : Vector α m) (n : Nat) := v[n]!

def bitvectorUpdate (v : BitVec m) (n : Nat) (b : Bool) := v[n]! = b

def vectorUpdate (v : Vector α m) (n : Nat) (a : α) := v.set! n a

def assert (p : Bool) (s : String) : PreSailM RegisterType c ue Unit :=
Expand Down
30 changes: 15 additions & 15 deletions test/lean/SailTinyArm.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1792,13 +1792,13 @@ def GPRs : (Vector (RegisterRef RegisterType (BitVec 64)) 31) :=

/-- Type quantifiers: n : Nat, 0 ≤ n ∧ n ≤ 31 -/
def wX (n : Nat) (value : (BitVec 64)) : SailM Unit := do
if (Ne n 31)
if (bne n 31)
then writeRegRef (vectorAccess GPRs n) value
else (pure ())

/-- Type quantifiers: n : Nat, 0 ≤ n ∧ n ≤ 31 -/
def rX (n : Nat) : SailM (BitVec 64) := do
if (Ne n 31)
if (bne n 31)
then (reg_deref (vectorAccess GPRs n))
else (pure (0x0000000000000000 : (BitVec 64)))

Expand All @@ -1812,11 +1812,11 @@ def decodeLoadStoreRegister (opc : (BitVec 2)) (Rm : (BitVec 5)) (option_v : (Bi
let t : reg_index := (BitVec.toNat Rt)
let n : reg_index := (BitVec.toNat Rn)
let m : reg_index := (BitVec.toNat Rm)
if (Bool.or (Ne option_v (0b011 : (BitVec 3))) (Eq S 1#1))
if (Bool.or (bne option_v (0b011 : (BitVec 3))) (Eq S 1#1))
then none
else if (Eq opc (0b00 : (BitVec 2)))
else if (BEq.beq opc (0b00 : (BitVec 2)))
then (some (LoadRegister (t, n, m)))
else if (Eq opc (0b01 : (BitVec 2)))
else if (BEq.beq opc (0b01 : (BitVec 2)))
then (some (StoreRegister (t, n, m)))
else none

Expand All @@ -1826,12 +1826,12 @@ def decodeExclusiveOr (sf : (BitVec 1)) (shift : (BitVec 2)) (N : (BitVec 1)) (R
let m : reg_index := (BitVec.toNat Rm)
if (Bool.and (Eq sf 0#1) (Eq (BitVec.access imm6 5) 1#1))
then none
else if (Ne imm6 (0b000000 : (BitVec 6)))
else if (bne imm6 (0b000000 : (BitVec 6)))
then none
else (some (ExclusiveOr (d, n, m)))

def decodeDataMemoryBarrier (CRm : (BitVec 4)) : (Option ast) :=
if (Ne CRm (0xF : (BitVec 4)))
if (bne CRm (0xF : (BitVec 4)))
then none
else (some (DataMemoryBarrier ()))

Expand Down Expand Up @@ -1917,7 +1917,7 @@ def execute_DataMemoryBarrier (_ : Unit) : SailM Unit := do
/-- Type quantifiers: t : Nat, 0 ≤ t ∧ t ≤ 31 -/
def execute_CompareAndBranch (t : Nat) (offset : (BitVec 64)) : SailM Unit := do
let operand ← do (rX t)
if (Eq operand (0x0000000000000000 : (BitVec 64)))
if (BEq.beq operand (0x0000000000000000 : (BitVec 64)))
then let base ← do (rPC ())
let addr := (HAdd.hAdd base offset)
(wPC addr)
Expand All @@ -1932,17 +1932,17 @@ def execute (merge_var : ast) : SailM Unit := do
| .CompareAndBranch (t, offset) => (execute_CompareAndBranch t offset)

def decode (v__0 : (BitVec 32)) : (Option ast) :=
if (Bool.and (Eq (Sail.BitVec.extractLsb v__0 31 24) (0xF8 : (BitVec 8)))
(Bool.and (Eq (Sail.BitVec.extractLsb v__0 21 21) (0b1 : (BitVec 1)))
(Eq (Sail.BitVec.extractLsb v__0 11 10) (0b10 : (BitVec 2)))))
if (Bool.and (BEq.beq (Sail.BitVec.extractLsb v__0 31 24) (0xF8 : (BitVec 8)))
(Bool.and (BEq.beq (Sail.BitVec.extractLsb v__0 21 21) (0b1 : (BitVec 1)))
(BEq.beq (Sail.BitVec.extractLsb v__0 11 10) (0b10 : (BitVec 2)))))
then let S := (BitVec.access v__0 12)
let option_v : (BitVec 3) := (Sail.BitVec.extractLsb v__0 15 13)
let opc : (BitVec 2) := (Sail.BitVec.extractLsb v__0 23 22)
let Rt : (BitVec 5) := (Sail.BitVec.extractLsb v__0 4 0)
let Rn : (BitVec 5) := (Sail.BitVec.extractLsb v__0 9 5)
let Rm : (BitVec 5) := (Sail.BitVec.extractLsb v__0 20 16)
(decodeLoadStoreRegister opc Rm option_v S Rn Rt)
else if (Eq (Sail.BitVec.extractLsb v__0 30 24) (0b1001010 : (BitVec 7)))
else if (BEq.beq (Sail.BitVec.extractLsb v__0 30 24) (0b1001010 : (BitVec 7)))
then let sf := (BitVec.access v__0 31)
let N := (BitVec.access v__0 21)
let shift : (BitVec 2) := (Sail.BitVec.extractLsb v__0 23 22)
Expand All @@ -1951,11 +1951,11 @@ def decode (v__0 : (BitVec 32)) : (Option ast) :=
let Rm : (BitVec 5) := (Sail.BitVec.extractLsb v__0 20 16)
let Rd : (BitVec 5) := (Sail.BitVec.extractLsb v__0 4 0)
(decodeExclusiveOr sf shift N Rm imm6 Rn Rd)
else if (Bool.and (Eq (Sail.BitVec.extractLsb v__0 31 12) (0xD5033 : (BitVec 20)))
(Eq (Sail.BitVec.extractLsb v__0 7 0) (0xBF : (BitVec 8))))
else if (Bool.and (BEq.beq (Sail.BitVec.extractLsb v__0 31 12) (0xD5033 : (BitVec 20)))
(BEq.beq (Sail.BitVec.extractLsb v__0 7 0) (0xBF : (BitVec 8))))
then let CRm : (BitVec 4) := (Sail.BitVec.extractLsb v__0 11 8)
(decodeDataMemoryBarrier CRm)
else if (Eq (Sail.BitVec.extractLsb v__0 31 24) (0xB4 : (BitVec 8)))
else if (BEq.beq (Sail.BitVec.extractLsb v__0 31 24) (0xB4 : (BitVec 8)))
then let imm19 : (BitVec 19) := (Sail.BitVec.extractLsb v__0 23 5)
let Rt : (BitVec 5) := (Sail.BitVec.extractLsb v__0 4 0)
(decodeCompareAndBranch imm19 Rt)
Expand Down
4 changes: 2 additions & 2 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def concat_str_dec (str : String) (x : Int) : String :=
(HAppend.hAppend str (Int.repr x))

def bitvector_eq (x : (BitVec 16)) (y : (BitVec 16)) : Bool :=
(Eq x y)
(BEq.beq x y)

def bitvector_neq (x : (BitVec 16)) (y : (BitVec 16)) : Bool :=
(Ne x y)
(bne x y)

def bitvector_len (x : (BitVec 16)) : Nat :=
(Sail.BitVec.length x)
Expand Down
16 changes: 8 additions & 8 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def spc_backwards (x_0 : String) : Unit :=

def spc_backwards_matches (s : String) : Bool :=
let len := (String.length s)
(Bool.and (Eq (String.leadingSpaces s) len) (GT.gt len 0))
(Bool.and (BEq.beq (String.leadingSpaces s) len) (GT.gt len 0))

def opt_spc_forwards (_ : Unit) : String :=
""
Expand All @@ -44,7 +44,7 @@ def opt_spc_backwards (x_0 : String) : Unit :=
()

def opt_spc_backwards_matches (s : String) : Bool :=
(Eq (String.leadingSpaces s) (String.length s))
(BEq.beq (String.leadingSpaces s) (String.length s))

def def_spc_forwards (_ : Unit) : String :=
" "
Expand All @@ -56,7 +56,7 @@ def def_spc_backwards (x_0 : String) : Unit :=
()

def def_spc_backwards_matches (s : String) : Bool :=
(Eq (String.leadingSpaces s) (String.length s))
(BEq.beq (String.leadingSpaces s) (String.length s))

def sep_forwards (arg_ : Unit) : String :=
match arg_ with
Expand Down Expand Up @@ -129,7 +129,7 @@ def extern_abs_int_plain (_ : Unit) : Int :=
(Sail.Int.intAbs x)

def extern_eq_unit (_ : Unit) : Bool :=
(Eq () ())
(BEq.beq () ())

def extern_eq_bit (_ : Unit) : Bool :=
(Eq 0#1 1#1)
Expand All @@ -147,10 +147,10 @@ def extern_or (_ : Unit) : Bool :=
(Bool.or true false)

def extern_eq_bool (_ : Unit) : Bool :=
(Eq true false)
(BEq.beq true false)

def extern_eq_int (_ : Unit) : Bool :=
(Eq 5 4)
(BEq.beq 5 4)

def extern_lteq_int (_ : Unit) : Bool :=
(LE.le 5 4)
Expand All @@ -165,7 +165,7 @@ def extern_gt_int (_ : Unit) : Bool :=
(GT.gt 5 4)

def extern_eq_anything (_ : Unit) : Bool :=
(Eq true true)
(BEq.beq true true)

def extern_vector_update (_ : Unit) : (Vector Int 5) :=
(vectorUpdate #v[23, 23, 23, 23, 23] 2 42)
Expand All @@ -186,7 +186,7 @@ def extern_string_startswith (_ : Unit) : Bool :=
(String.startsWith "Hello, world" "Hello")

def extern_eq_string (_ : Unit) : Bool :=
(Eq "Hello" "world")
(BEq.beq "Hello" "world")

def extern_concat_str (_ : Unit) : String :=
(HAppend.hAppend "Hello, " "world")
Expand Down
6 changes: 3 additions & 3 deletions test/lean/ite.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def concat_str_dec (str : String) (x : Int) : String :=

/-- Type quantifiers: n : Nat, 0 ≤ n -/
def elif (n : Nat) : (BitVec 1) :=
if (Eq n 0)
if (BEq.beq n 0)
then 1#1
else if (Eq n 1)
else if (BEq.beq n 1)
then 1#1
else 0#1

Expand All @@ -118,7 +118,7 @@ def monadic_in_out (n : Nat) : SailM Nat := do

/-- Type quantifiers: n : Nat, 0 ≤ n -/
def monadic_lines (n : Nat) : SailM Unit := do
let b := (Eq n 0)
let b := (BEq.beq n 0)
if b
then writeReg R n
writeReg B b
Expand Down
42 changes: 21 additions & 21 deletions test/lean/mapping.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def size_bits_forwards (arg_ : word_width) : (BitVec 2) :=

def size_bits_backwards (arg_ : (BitVec 2)) : word_width :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then BYTE
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then HALF
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then WORD
else DOUBLE

Expand All @@ -135,13 +135,13 @@ def size_bits_forwards_matches (arg_ : word_width) : Bool :=

def size_bits_backwards_matches (arg_ : (BitVec 2)) : Bool :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then true
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then true
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then true
else if (Eq b__0 (0b11 : (BitVec 2)))
else if (BEq.beq b__0 (0b11 : (BitVec 2)))
then true
else false

Expand All @@ -154,11 +154,11 @@ def size_bits2_forwards (arg_ : word_width) : (BitVec 2) :=

def size_bits2_backwards (arg_ : (BitVec 2)) : word_width :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then BYTE
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then HALF
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then WORD
else DOUBLE

Expand All @@ -171,13 +171,13 @@ def size_bits2_forwards_matches (arg_ : word_width) : Bool :=

def size_bits2_backwards_matches (arg_ : (BitVec 2)) : Bool :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then true
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then true
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then true
else if (Eq b__0 (0b11 : (BitVec 2)))
else if (BEq.beq b__0 (0b11 : (BitVec 2)))
then true
else false

Expand All @@ -190,11 +190,11 @@ def size_bits3_forwards (arg_ : word_width) : (BitVec 2) :=

def size_bits3_backwards (arg_ : (BitVec 2)) : word_width :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then BYTE
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then HALF
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then WORD
else DOUBLE

Expand All @@ -207,13 +207,13 @@ def size_bits3_forwards_matches (arg_ : word_width) : Bool :=

def size_bits3_backwards_matches (arg_ : (BitVec 2)) : Bool :=
let b__0 := arg_
if (Eq b__0 (0b00 : (BitVec 2)))
if (BEq.beq b__0 (0b00 : (BitVec 2)))
then true
else if (Eq b__0 (0b01 : (BitVec 2)))
else if (BEq.beq b__0 (0b01 : (BitVec 2)))
then true
else if (Eq b__0 (0b10 : (BitVec 2)))
else if (BEq.beq b__0 (0b10 : (BitVec 2)))
then true
else if (Eq b__0 (0b11 : (BitVec 2)))
else if (BEq.beq b__0 (0b11 : (BitVec 2)))
then true
else false

Expand Down
Loading

0 comments on commit 8f90ea6

Please sign in to comment.