Skip to content

Commit

Permalink
Lean: improve handling of arguments
Browse files Browse the repository at this point in the history
1. We handle irrefutable patterns in argument positions by adding an
   equivalent let in the prelude of the function.
2. We handle `atom` types deep inside the type of the arguments. We
   translate the pattern in the argument as above as a variable, and we
   replace the sail type variable by `tuple.2.1` in the return type.
   This path and the variable bound in the let binding in the prelude
   are definitionally equal.
3. As a side fix, we use more unique names for the autogenerated
   variable names by numbering them, which should avoid some spurious
   shadowing.
  • Loading branch information
ineol authored and Alasdair committed Feb 16, 2025
1 parent ecfe107 commit c6458c5
Show file tree
Hide file tree
Showing 25 changed files with 259 additions and 106 deletions.
81 changes: 60 additions & 21 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type context = {
global : global_context;
env : Type_check.env;
(** The typechecking environment of the current function. This environment is reset using [initial_context] when
we start processing a new function. *)
we start processing a new function. Note that we use it to store paths of the form id.x.y.z. *)
kid_id_renames : id option KBindings.t;
(** Associates a kind variable to the corresponding argument of the function, used for implicit arguments. *)
kid_id_renames_rev : kid Bindings.t; (** Inverse of the [kid_id_renames] mapping. *)
Expand Down Expand Up @@ -67,19 +67,22 @@ let doc_kid ctx (Kid_aux (Var x, _) as ki) =

let is_enum env id = match Env.lookup_id id env with Enum _ -> true | _ -> false

let pat_is_plain_binder env (P_aux (p, _)) =
let pat_is_plain_binder ?(suffix = "") env (P_aux (p, _)) =
match p with
| P_id id when not (is_enum env id) -> Some (Some id, None)
| P_id _ -> Some (Some (Id_aux (Id ("id" ^ suffix), Unknown)), None)
| P_typ (typ, P_aux (P_id id, _)) when not (is_enum env id) -> Some (Some id, Some typ)
| P_wild | P_typ (_, P_aux (P_wild, _)) -> Some (None, None)
| P_var (_, _) -> Some (Some (Id_aux (Id "var", Unknown)), None)
| P_app (_, _) -> Some (Some (Id_aux (Id "app", Unknown)), None)
| P_vector _ -> Some (Some (Id_aux (Id "vect", Unknown)), None)
| P_tuple _ -> Some (Some (Id_aux (Id "tuple", Unknown)), None)
| P_list _ -> Some (Some (Id_aux (Id "list", Unknown)), None)
| P_cons (_, _) -> Some (Some (Id_aux (Id "cons", Unknown)), None)
| P_var (_, _) -> Some (Some (Id_aux (Id ("var" ^ suffix), Unknown)), None)
| P_app (_, _) -> Some (Some (Id_aux (Id ("app" ^ suffix), Unknown)), None)
| P_vector _ -> Some (Some (Id_aux (Id ("vect" ^ suffix), Unknown)), None)
| P_tuple _ -> Some (Some (Id_aux (Id ("tuple" ^ suffix), Unknown)), None)
| P_list _ -> Some (Some (Id_aux (Id ("list" ^ suffix), Unknown)), None)
| P_cons (_, _) -> Some (Some (Id_aux (Id ("cons" ^ suffix), Unknown)), None)
| P_lit (L_aux (L_unit, _)) -> Some (Some (Id_aux (Id "_", Unknown)), None)
| P_lit _ -> Some (Some (Id_aux (Id "lit", Unknown)), None)
| P_lit _ -> Some (Some (Id_aux (Id ("lit" ^ suffix), Unknown)), None)
| P_typ _ -> Some (Some (Id_aux (Id ("typ" ^ suffix), Unknown)), None)
| P_struct _ -> Some (Some (Id_aux (Id ("struct_pat" ^ suffix), Unknown)), None)
| _ -> None

(* Copied from the Coq PP *)
Expand Down Expand Up @@ -434,7 +437,8 @@ let rec doc_pat ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
| P_vector pats -> concat (List.map (doc_pat ~in_vector:true) pats)
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
| P_app (Id_aux (Id "None", _), p) -> string "none"
| P_app (cons, pats) -> doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| P_app (cons, pats) ->
string "." ^^ doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| P_var (p, _) -> doc_pat p
| P_as (pat, id) -> doc_pat pat
| P_struct (pats, _) ->
Expand Down Expand Up @@ -668,15 +672,47 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp false ctx e

let doc_binder ctx i t =
let paranthesizer =
let parenthesizer =
match t with
| Typ_aux (Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]), _) ->
implicit_parens
| _ -> parens
in
(* Overwrite the id if it's captured *)
let ctx = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctx i ki | _ -> ctx in
(ctx, separate space [doc_id_ctor i; colon; doc_typ ctx t] |> paranthesizer)
(ctx, separate space [doc_id_ctor i; colon; doc_typ ctx t] |> parenthesizer)

(** Find all patterns in the arguments of the sail function that Lean cannot handle in a [def],
and add them as let bindings in the prelude of the translation of the function. This assumes
that the pattern is irrefutable. *)
let add_function_pattern ctx fixup_binders (P_aux (pat, pat_annot) as pat_full) var typ =
match pat with
| P_id _ | P_typ (_, P_aux (P_id _, _)) | P_tuple [] | P_lit _ | P_wild -> fixup_binders
| _ ->
fun (E_aux (_, body_annot) as body : tannot exp) ->
E_aux
( E_let (LB_aux (LB_val (pat_full, E_aux (E_id var, (Unknown, mk_tannot ctx.env typ))), pat_annot), body),
body_annot
)
|> fixup_binders

(** Find all the [int] and [atom] types in the function pattern and express them as paths that use the
lean variables, so that we can use them in the return type of the function. For example, see the function
[two_tuples_atom] in the test case test/lean/typquant.sail.
*)
let rec add_path_renamings ~path ctx (P_aux (pat, pat_annot)) (Typ_aux (typ, typ_annot) as typ_full) =
match (pat, typ) with
| P_tuple pats, Typ_tuple typs ->
List.fold_left
(fun (ctx, i) (pat, typ) -> (add_path_renamings ~path:(Printf.sprintf "%s.%i" path i) ctx pat typ, i + 1))
(ctx, 1) (List.combine pats typs)
|> fst
| P_id id, typ -> (
match captured_typ_var (id, typ_full) with
| Some (_, kid) -> add_single_kid_id_rename ctx (mk_id path) kid
| None -> ctx
)
| _ -> ctx

let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
Expand All @@ -688,23 +724,26 @@ let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
in
let pat, _, exp, _ = destruct_pexp pexp in
let pats, fixup_binders = untuple_args_pat arg_typs pat in
let binders : (id * typ) list =
let binders : (tannot pat * id * typ) list =
pats
|> List.map (fun (pat, typ) ->
match pat_is_plain_binder env pat with
| Some (Some id, _) -> (id, typ)
| Some (None, _) -> (Id_aux (Id "x", l), typ) (* TODO fresh name or wildcard instead of x *)
|> List.mapi (fun i (pat, typ) ->
match pat_is_plain_binder ~suffix:(Printf.sprintf "_%i" i) env pat with
| Some (Some id, _) -> (pat, id, typ)
| Some (None, _) ->
(pat, mk_id ~loc:l (Printf.sprintf "x_%i" i), typ) (* TODO fresh name or wildcard instead of x *)
| _ -> failwith "Argument pattern not translatable yet."
)
in
let ctx = context_init env global in
let ctx, binders =
let ctx, binders, fixup_binders =
List.fold_left
(fun (ctx, bs) (i, t) ->
(fun (ctx, bs, fixup_binders) (pat, i, t) ->
let ctx, d = doc_binder ctx i t in
(ctx, bs @ [d])
let fixup_binders = add_function_pattern ctx fixup_binders pat i t in
let ctx = add_path_renamings ~path:(string_of_id i) ctx pat t in
(ctx, bs @ [d], fixup_binders)
)
(ctx, []) binders
(ctx, [], fixup_binders) binders
in
let typ_quant_comment = doc_typ_quant_in_comment ctx tq_all in
(* Use auto-implicits for type quanitifiers for now and see if this works *)
Expand Down
68 changes: 34 additions & 34 deletions test/lean/SailTinyArm.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -558,13 +558,13 @@ def fmod_int (n : Int) (m : Int) : Int :=
/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
match opt with
| some _ => false
| .some _ => false
| none => true

/-- Type quantifiers: k_a : Type -/
def is_some (opt : (Option k_a)) : Bool :=
match opt with
| some _ => true
| .some _ => true
| none => false

/-- Type quantifiers: k_n : Int -/
Expand Down Expand Up @@ -1850,11 +1850,11 @@ def wMem (addr : (BitVec 64)) (value : (BitVec 64)) : SailM Unit := do
value := (some value)
tag := none }
match (← (sail_mem_write req)) with
| Ok _ => (pure ())
| Err _ => throw Error.Exit
| .Ok _ => (pure ())
| .Err _ => throw Error.Exit

/-- Type quantifiers: x : Nat, x ∈ {32, 64} -/
def sail_address_announce (x : Nat) (x : (BitVec x)) : Unit :=
/-- Type quantifiers: x_0 : Nat, x_0 ∈ {32, 64} -/
def sail_address_announce (x_0 : Nat) (x_1 : (BitVec x_0)) : Unit :=
()

def wMem_Addr (addr : (BitVec 64)) : Unit :=
Expand Down Expand Up @@ -1882,8 +1882,8 @@ def rMem (addr : (BitVec 64)) : SailM (BitVec 64) := do
size := 8
tag := false }
match (← (sail_mem_read req)) with
| Ok (value, _) => (pure value)
| Err _ => throw Error.Exit
| .Ok (value, _) => (pure value)
| .Err _ => throw Error.Exit

/-- Type quantifiers: m : Nat, n : Nat, t : Nat, 0 ≤ t ∧ t ≤ 31, 0 ≤ n ∧ n ≤ 31, 0 ≤ m
∧ m ≤ 31 -/
Expand Down Expand Up @@ -1923,11 +1923,11 @@ def execute_CompareAndBranch (t : Nat) (offset : (BitVec 64)) : SailM Unit := do

def execute (merge_var : ast) : SailM Unit := do
match merge_var with
| LoadRegister (t, n, m) => (execute_LoadRegister t n m)
| StoreRegister (t, n, m) => (execute_StoreRegister t n m)
| ExclusiveOr (d, n, m) => (execute_ExclusiveOr d n m)
| DataMemoryBarrier arg0 => (execute_DataMemoryBarrier arg0)
| CompareAndBranch (t, offset) => (execute_CompareAndBranch t offset)
| .LoadRegister (t, n, m) => (execute_LoadRegister t n m)
| .StoreRegister (t, n, m) => (execute_StoreRegister t n m)
| .ExclusiveOr (d, n, m) => (execute_ExclusiveOr d n m)
| .DataMemoryBarrier arg0 => (execute_DataMemoryBarrier arg0)
| .CompareAndBranch (t, offset) => (execute_CompareAndBranch t offset)

def decode (v__0 : (BitVec 32)) : (Option ast) :=
if (Bool.and (Eq (Sail.BitVec.extractLsb v__0 31 24) (0xF8 : (BitVec 8)))
Expand Down Expand Up @@ -1968,52 +1968,52 @@ def iFetch (addr : (BitVec 64)) : SailM (BitVec 32) := do
size := 4
tag := false }
match (← (sail_mem_read req)) with
| Ok (value, _) => (pure value)
| Err _ => throw Error.Exit
| .Ok (value, _) => (pure value)
| .Err _ => throw Error.Exit

def fetch_and_execute (_ : Unit) : SailM Unit := do
let machineCode ← do (iFetch (← readReg _PC))
let instr := (decode machineCode)
match instr with
| some instr => (execute instr)
| .some instr => (execute instr)
| none => assert false "Unsupported Encoding"

/-- Type quantifiers: k_a : Type, k_b : Type -/
def is_ok (r : (Result k_a k_b)) : Bool :=
match r with
| Ok _ => true
| Err _ => false
| .Ok _ => true
| .Err _ => false

/-- Type quantifiers: k_a : Type, k_b : Type -/
def is_err (r : (Result k_a k_b)) : Bool :=
match r with
| Ok _ => false
| Err _ => true
| .Ok _ => false
| .Err _ => true

/-- Type quantifiers: k_a : Type, k_b : Type -/
def ok_option (r : (Result k_a k_b)) : (Option k_a) :=
match r with
| Ok x => (some x)
| Err _ => none
| .Ok x => (some x)
| .Err _ => none

/-- Type quantifiers: k_a : Type, k_b : Type -/
def err_option (r : (Result k_a k_b)) : (Option k_b) :=
match r with
| Ok _ => none
| Err err => (some err)
| .Ok _ => none
| .Err err => (some err)

/-- Type quantifiers: k_a : Type, k_b : Type -/
def unwrap_or (r : (Result k_a k_b)) (y : k_a) : k_a :=
match r with
| Ok x => x
| Err _ => y
| .Ok x => x
| .Err _ => y

/-- Type quantifiers: k_n : Nat, k_n > 0 -/
def sail_instr_announce (x : (BitVec k_n)) : Unit :=
def sail_instr_announce (x_0 : (BitVec k_n)) : Unit :=
()

/-- Type quantifiers: x : Nat, x ∈ {32, 64} -/
def sail_branch_announce (x : Nat) (x : (BitVec x)) : Unit :=
/-- Type quantifiers: x_0 : Nat, x_0 ∈ {32, 64} -/
def sail_branch_announce (x_0 : Nat) (x_1 : (BitVec x_0)) : Unit :=
()

def sail_reset_registers (_ : Unit) : Unit :=
Expand All @@ -2023,11 +2023,11 @@ def sail_synchronize_registers (_ : Unit) : Unit :=
()

/-- Type quantifiers: k_a : Type -/
def sail_mark_register (x : (RegisterRef RegisterType k_a)) (x : String) : Unit :=
def sail_mark_register (x_0 : (RegisterRef RegisterType k_a)) (x_1 : String) : Unit :=
()

/-- Type quantifiers: k_a : Type, k_b : Type -/
def sail_mark_register_pair (x : (RegisterRef RegisterType k_a)) (x : (RegisterRef RegisterType k_b)) (x : String) : Unit :=
def sail_mark_register_pair (x_0 : (RegisterRef RegisterType k_a)) (x_1 : (RegisterRef RegisterType k_b)) (x_2 : String) : Unit :=
()

/-- Type quantifiers: k_a : Type -/
Expand Down Expand Up @@ -2082,7 +2082,7 @@ def undefined_Explicit_access_kind (_ : Unit) : SailM Explicit_access_kind := do
: Type, k_n > 0 ∧ k_vasize > 0 -/
def mem_read_request_is_exclusive (request : Mem_read_request k_n k_vasize k_pa k_translation_summary k_arch_ak) : Bool :=
match request.access_kind with
| AK_explicit eak =>
| .AK_explicit eak =>
match eak.variety with
| AV_exclusive => true
| _ => false
Expand All @@ -2092,7 +2092,7 @@ def mem_read_request_is_exclusive (request : Mem_read_request k_n k_vasize k_pa
: Type, k_n > 0 ∧ k_vasize > 0 -/
def mem_read_request_is_ifetch (request : Mem_read_request k_n k_vasize k_pa k_translation_summary k_arch_ak) : Bool :=
match request.access_kind with
| AK_ifetch () => true
| .AK_ifetch () => true
| _ => false

def __monomorphize_reads : Bool := false
Expand All @@ -2103,7 +2103,7 @@ def __monomorphize_writes : Bool := false
: Type, k_n > 0 ∧ k_vasize > 0 -/
def mem_write_request_is_exclusive (request : Mem_write_request k_n k_vasize k_pa k_translation_summary k_arch_ak) : Bool :=
match request.access_kind with
| AK_explicit eak =>
| .AK_explicit eak =>
match eak.variety with
| AV_exclusive => true
| _ => false
Expand Down
4 changes: 2 additions & 2 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def fmod_int (n : Int) (m : Int) : Int :=
/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
match opt with
| some _ => false
| .some _ => false
| none => true

/-- Type quantifiers: k_a : Type -/
def is_some (opt : (Option k_a)) : Bool :=
match opt with
| some _ => true
| .some _ => true
| none => false

/-- Type quantifiers: k_n : Int -/
Expand Down
4 changes: 2 additions & 2 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def fmod_int (n : Int) (m : Int) : Int :=
/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
match opt with
| some _ => false
| .some _ => false
| none => true

/-- Type quantifiers: k_a : Type -/
def is_some (opt : (Option k_a)) : Bool :=
match opt with
| some _ => true
| .some _ => true
| none => false

/-- Type quantifiers: k_n : Int -/
Expand Down
4 changes: 2 additions & 2 deletions test/lean/enum.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def fmod_int (n : Int) (m : Int) : Int :=
/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
match opt with
| some _ => false
| .some _ => false
| none => true

/-- Type quantifiers: k_a : Type -/
def is_some (opt : (Option k_a)) : Bool :=
match opt with
| some _ => true
| .some _ => true
| none => false

/-- Type quantifiers: k_n : Int -/
Expand Down
4 changes: 2 additions & 2 deletions test/lean/errors.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ def fmod_int (n : Int) (m : Int) : Int :=
/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
match opt with
| some _ => false
| .some _ => false
| none => true

/-- Type quantifiers: k_a : Type -/
def is_some (opt : (Option k_a)) : Bool :=
match opt with
| some _ => true
| .some _ => true
| none => false

/-- Type quantifiers: k_n : Int -/
Expand Down
6 changes: 3 additions & 3 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def spc_forwards (_ : Unit) : String :=
def spc_forwards_matches (_ : Unit) : Bool :=
true

def spc_backwards (x : String) : Unit :=
def spc_backwards (x_0 : String) : Unit :=
()

def spc_backwards_matches (s : String) : Bool :=
Expand All @@ -38,7 +38,7 @@ def opt_spc_forwards (_ : Unit) : String :=
def opt_spc_forwards_matches (_ : Unit) : Bool :=
true

def opt_spc_backwards (x : String) : Unit :=
def opt_spc_backwards (x_0 : String) : Unit :=
()

def opt_spc_backwards_matches (s : String) : Bool :=
Expand All @@ -50,7 +50,7 @@ def def_spc_forwards (_ : Unit) : String :=
def def_spc_forwards_matches (_ : Unit) : Bool :=
true

def def_spc_backwards (x : String) : Unit :=
def def_spc_backwards (x_0 : String) : Unit :=
()

def def_spc_backwards_matches (s : String) : Bool :=
Expand Down
Loading

0 comments on commit c6458c5

Please sign in to comment.