Skip to content

Commit

Permalink
Fixing calls to effectful functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored and Alasdair committed Jan 29, 2025
1 parent db87e30 commit 5740618
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 33 deletions.
40 changes: 19 additions & 21 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@ type context = {
kid_id_renames_rev : kid Bindings.t; (** Inverse of the [kid_id_renames] mapping. *)
}

let initial_context env =
{
global = { effect_info = Effects.empty_side_effect_info };
env;
kid_id_renames = KBindings.empty;
kid_id_renames_rev = Bindings.empty;
}
let initial_context env global = { global; env; kid_id_renames = KBindings.empty; kid_id_renames_rev = Bindings.empty }

let add_single_kid_id_rename ctx id kid =
let kir =
Expand Down Expand Up @@ -379,7 +373,7 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
begin
match pat with
| P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ""
| _ -> flow (break 1) [string "let"; e0; string ""] ^^ space
| _ -> flow (break 1) [string "let"; e0; string ":="] ^^ space
end
in
nest 2 (e0_pp ^^ e1_pp) ^^ hardline ^^ e2_pp
Expand All @@ -390,12 +384,15 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
in
let d_args = List.map d_of_arg args in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2 (wrap_with_pure (as_monadic && fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
nest 2
(wrap_with_left_arrow ((not as_monadic) && fn_monadic)
(wrap_with_pure (as_monadic && not fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
)
| E_vector vals ->
string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals))))
| E_typ (typ, e) ->
if effectful (effect_of e) then
parens (separate space [doc_exp false ctx e; colon; string "SailM"; doc_typ ctx typ])
parens (separate space [doc_exp as_monadic ctx e; colon; string "SailM"; doc_typ ctx typ])
else wrap_with_pure as_monadic (parens (separate space [doc_exp false ctx e; colon; doc_typ ctx typ]))
| E_tuple es -> wrap_with_pure as_monadic (parens (separate_map (comma ^^ space) d_of_arg es))
| E_let (LB_aux (LB_val (lpat, lexp), _), e) ->
Expand Down Expand Up @@ -455,7 +452,7 @@ let doc_binder ctx i t =
let ctx = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctx i ki | _ -> ctx in
(ctx, separate space [string (string_of_id i); colon; doc_typ ctx t] |> paranthesizer)

let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
let TypQ_aux (tq, l), typ = Env.get_val_spec_orig id env in
let arg_typs, ret_typ, _ =
Expand All @@ -474,7 +471,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
| _ -> failwith "Argument pattern not translatable yet."
)
in
let ctx = initial_context env in
let ctx = initial_context env global in
let ctx, binders =
List.fold_left
(fun (ctx, bs) (i, t) ->
Expand Down Expand Up @@ -503,19 +500,19 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
fixup_binders
)

let doc_funcl_body fixup_binders (FCL_aux (FCL_funcl (id, pexp), annot)) =
let doc_funcl_body fixup_binders global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
let ctx = initial_context env in
let _, _, exp, _ = destruct_pexp pexp in
(* If an argument was [x : (Int, Int)], which is transformed to [(arg0: Int) (arg1: Int)],
this adds a let binding at the beginning of the function, of the form [let x := (arg0, arg1)] *)
let exp = fixup_binders exp in
let is_monadic = effectful (effect_of exp) in
doc_exp is_monadic (initial_context env) exp
doc_exp is_monadic (initial_context env global) exp

let doc_funcl ctx funcl =
let comment, signature, env, fixup_binders = doc_funcl_init funcl in
comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders funcl)
let comment, signature, env, fixup_binders = doc_funcl_init ctx.global funcl in
comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders ctx.global funcl)

let doc_fundef ctx (FD_aux (FD_function (r, typa, fcls), fannot)) =
match fcls with
Expand Down Expand Up @@ -642,8 +639,8 @@ let inhabit_enum ctx typ_map =
)
typ_map

let doc_reg_info env registers =
let ctx = initial_context env in
let doc_reg_info env global registers =
let ctx = initial_context env global in

let type_map = List.fold_left add_reg_typ Bindings.empty registers in
let type_map = Bindings.bindings type_map in
Expand All @@ -660,10 +657,11 @@ let doc_reg_info env registers =
empty;
]

let pp_ast_lean (env : Type_check.env) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let pp_ast_lean (env : Type_check.env) effect_info ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let regs = State.find_registers defs in
let register_refs = match regs with [] -> empty | _ -> doc_reg_info env regs in
let types, fundefs = doc_defs (initial_context env) defs in
let global = { effect_info } in
let register_refs = match regs with [] -> empty | _ -> doc_reg_info env global regs in
let types, fundefs = doc_defs (initial_context env global) defs in
print o (types ^^ register_refs ^^ fundefs);
()
6 changes: 3 additions & 3 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ let create_lake_project (out_name : string) default_sail_dir =
output_string project_main "open Sail\n\n";
project_main

let output (out_name : string) env ast default_sail_dir =
let output (out_name : string) env effect_info ast default_sail_dir =
let project_main = create_lake_project out_name default_sail_dir in
(* Uncomment for debug output of the Sail code after the rewrite passes *)
(* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *)
Pretty_print_lean.pp_ast_lean env ast project_main;
Pretty_print_lean.pp_ast_lean env effect_info ast project_main;
close_out project_main

let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } =
let out_name = match out_name with Some f -> f | None -> "out" in
output out_name env ast default_sail_dir
output out_name env effect_info ast default_sail_dir

let _ = Target.register ~name:"lean" ~options:lean_options ~rewrites:lean_rewrites ~asserts_termination:true lean_target
14 changes: 7 additions & 7 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x)

def _set_cr_type_bits (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 8)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_bits r v)

def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) :=
Expand All @@ -41,7 +41,7 @@ def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 4 x)

def _set_cr_type_CR0 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 4)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_CR0 r v)

def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) :=
Expand All @@ -51,7 +51,7 @@ def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 3 2 x)

def _set_cr_type_CR1 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_CR1 r v)

def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) :=
Expand All @@ -61,7 +61,7 @@ def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 1 0 x)

def _set_cr_type_CR3 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_CR3 r v)

def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) :=
Expand All @@ -71,7 +71,7 @@ def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 6 6 x)

def _set_cr_type_GT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_GT r v)

def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) :=
Expand All @@ -81,9 +81,9 @@ def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 7 x)

def _set_cr_type_LT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do
let r ← (reg_deref r_ref)
let r := (← (reg_deref r_ref))
writeRegRef r_ref (_update_cr_type_LT r v)

def initialize_registers : SailM Unit := do
writeReg R (undefined_cr_type ())
writeReg R (← (undefined_cr_type ()))

18 changes: 16 additions & 2 deletions test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ structure My_struct where
field1 : Int
field2 : (BitVec 1)

inductive Register : Type where
| r
deriving DecidableEq, Hashable
open Register

abbrev RegisterType : Register → Type
| .r => My_struct

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType My_struct) where
default := .Reg r

def undefined_My_struct (lit : Unit) : SailM My_struct := do
(pure { field1 := (← sorry)
field2 := (← sorry) })
Expand All @@ -28,6 +42,6 @@ def mk_struct (i : Int) (b : (BitVec 1)) : My_struct :=
def undef_struct (x : (BitVec 1)) : SailM My_struct := do
((undefined_My_struct ()) : SailM My_struct)

def initialize_registers : Unit :=
()
def initialize_registers : SailM Unit := do
writeReg r (← (undefined_My_struct ()))

2 changes: 2 additions & 0 deletions test/lean/struct.sail
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ struct My_struct = {
field2 : bit,
}

register r : My_struct

val struct_field2 : My_struct -> bit
function struct_field2(s) = {
s.field2
Expand Down

0 comments on commit 5740618

Please sign in to comment.