From 49d5bedd6ab684f27fd7ed56ba307f2d228e2443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Stefanesco?= Date: Mon, 27 Jan 2025 15:17:57 +0000 Subject: [PATCH] Add support to match expressions in Lean backend (#903) --- src/sail_lean_backend/pretty_print_lean.ml | 52 +++++++++++++++------- src/sail_lean_backend/sail_plugin_lean.ml | 4 +- test/lean/enum.expected.lean | 1 + test/lean/match.expected.lean | 40 +++++++++++++++++ test/lean/match.sail | 33 ++++++++++++++ 5 files changed, 113 insertions(+), 17 deletions(-) create mode 100644 test/lean/match.expected.lean create mode 100644 test/lean/match.sail diff --git a/src/sail_lean_backend/pretty_print_lean.ml b/src/sail_lean_backend/pretty_print_lean.ml index 9cfb18e32..c2468d303 100644 --- a/src/sail_lean_backend/pretty_print_lean.ml +++ b/src/sail_lean_backend/pretty_print_lean.ml @@ -82,7 +82,8 @@ let args_of_typ l env typs = especially so that the function is presented in curried form. In particular, if there's a single binder for multiple arguments (which rewriting can currently introduce) then we need to turn it - into multiple binders and reconstruct it in the function body. *) + into multiple binders and reconstruct it in the function body using + the second return value of this function. *) let rec untuple_args_pat typs (P_aux (paux, ((l, _) as annot)) as pat) = let env = env_of_annot annot in let identity body = body in @@ -155,6 +156,7 @@ let rec doc_typ ctx (Typ_aux (t, _) as typ) = string "RegisterRef RegisterType " ^^ separate_map comma (doc_typ_app ctx) t_app | Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) -> underscore (* TODO check if the type of implicit arguments can really be always inferred *) + | Typ_app (Id_aux (Id "option", _), [A_aux (A_typ typ, _)]) -> parens (string "Option " ^^ doc_typ ctx typ) | Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) (doc_typ ctx) ts) | Typ_id (Id_aux (Id id, _)) -> string id | Typ_app (Id_aux (Id "range", _), [A_aux (A_nexp low, _); A_aux (A_nexp high, _)]) -> @@ -292,15 +294,19 @@ let string_of_pat_con (P_aux (p, _)) = | P_string_append _ -> "P_string_append" | P_struct _ -> "P_struct" -let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot)) as pat) = - let env = env_of_annot (l, annot) in - let typ = Env.expand_synonyms env (typ_of_annot (l, annot)) in +(** Fix identifiers to match the standard Lean library. *) +let fixup_match_id (Id_aux (id, l) as id') = + match id with Id id -> Id_aux (Id (match id with "Some" -> "some" | "None" -> "none" | _ -> id), l) | _ -> id' + +let rec doc_pat (P_aux (p, (l, annot)) as pat) = match p with - | P_typ (ptyp, p) -> - let doc_p = doc_pat ctxt true p in - doc_p - | P_id id -> doc_id_ctor id | P_wild -> underscore + | P_lit lit -> doc_lit lit + | P_typ (ptyp, p) -> doc_pat p + | P_id id -> fixup_match_id id |> doc_id_ctor + | P_tuple pats -> separate (string ", ") (List.map doc_pat pats) |> parens + | P_list pats -> separate (string ", ") (List.map doc_pat pats) |> brackets + | P_app (cons, pats) -> doc_id_ctor cons ^^ space ^^ separate_map (string ", ") doc_pat pats | _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.") (* Copied from the Coq PP *) @@ -333,7 +339,12 @@ let wrap_with_pure (needs_return : bool) (d : document) = let wrap_with_left_arrow (needs_return : bool) (d : document) = if needs_return then parens (nest 2 (flow space [string "←"; d])) else d -let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = +let rec doc_match_clause ctx (Pat_aux (cl, l)) = + match cl with + | Pat_exp (pat, branch) -> string "| " ^^ doc_pat pat ^^ string " =>" ^^ space ^^ doc_exp false ctx branch + | Pat_when (pat, when_, branch) -> failwith "The Lean backend does not support 'when' clauses in patterns" + +and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = let env = env_of_tannot annot in let d_of_arg arg = let arg_monadic = effectful (effect_of arg) in @@ -357,7 +368,7 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = (* TODO replace by actual implementation of internal_pick *) string "sorry" | E_internal_plet (pat, e1, e2) -> - let e0 = doc_pat ctx false pat in + let e0 = doc_pat pat in let e1_pp = doc_exp false ctx e1 in let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in let e2_pp = doc_exp false ctx e2' in @@ -407,6 +418,9 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = (* TODO *) wrap_with_pure as_monadic (braces (space ^^ doc_exp false ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space)) + | E_match (discr, brs) -> + let cases = hardline ^^ separate_map hardline (fun br -> doc_match_clause ctx br) brs ^^ hardline in + string "match " ^^ doc_exp (effectful (effect_of discr)) ctx discr ^^ string " with" ^^ cases | E_assign ((LE_aux (le_act, tannot) as le), e) -> ( match le_act with | LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e @@ -438,7 +452,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in let pat, _, exp, _ = destruct_pexp pexp in - let pats, _ = untuple_args_pat arg_typs pat in + let pats, fixup_binders = untuple_args_pat arg_typs pat in let binders : (id * typ) list = pats |> List.map (fun (pat, typ) -> @@ -471,18 +485,25 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = let decl_val = [doc_ret_typ; coloneq] in (* Add do block for stateful functions *) let decl_val = if is_monadic then decl_val @ [string "do"] else decl_val in - (typ_quant_comment, separate space ([string "def"; string (string_of_id id)] @ binders @ [colon] @ decl_val), env) + ( typ_quant_comment, + separate space ([string "def"; string (string_of_id id)] @ binders @ [colon] @ decl_val), + env, + fixup_binders + ) -let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) = +let doc_funcl_body fixup_binders (FCL_aux (FCL_funcl (id, pexp), annot)) = let env = env_of_tannot (snd annot) in let ctx = initial_context env in let _, _, exp, _ = destruct_pexp pexp in + (* If an argument was [x : (Int, Int)], which is transformed to [(arg0: Int) (arg1: Int)], + this adds a let binding at the beginning of the function, of the form [let x := (arg0, arg1)] *) + let exp = fixup_binders exp in let is_monadic = effectful (effect_of exp) in doc_exp is_monadic (initial_context env) exp let doc_funcl ctx funcl = - let comment, signature, env = doc_funcl_init funcl in - comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body funcl) + let comment, signature, env, fixup_binders = doc_funcl_init funcl in + comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders funcl) let doc_fundef ctx (FD_aux (FD_function (r, typa, fcls), fannot)) = match fcls with @@ -511,6 +532,7 @@ let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) = ^^ enums_doc ^^ hardline ^^ string "deriving" ^^ space ^^ separate (comma ^^ space) derivers ) + ^^ hardline ^^ string "open " ^^ string id | TD_record (Id_aux (Id id, _), TypQ_aux (tq, _), fields, _) -> let fields = List.map (doc_typ_id ctx) fields in let enums_doc = separate hardline fields in diff --git a/src/sail_lean_backend/sail_plugin_lean.ml b/src/sail_lean_backend/sail_plugin_lean.ml index c9274a18a..dc724d376 100644 --- a/src/sail_lean_backend/sail_plugin_lean.ml +++ b/src/sail_lean_backend/sail_plugin_lean.ml @@ -108,8 +108,8 @@ let lean_rewrites = ("simple_assignments", []); ("remove_vector_concat", []); ("remove_bitvector_pats", []); - ("remove_numeral_pats", []); - ("pattern_literals", [Literal_arg "lem"]); + (* ("remove_numeral_pats", []); *) + (* ("pattern_literals", [Literal_arg "lem"]); *) ("guarded_pats", []); (* ("register_ref_writes", rewrite_register_ref_writes); *) ("nexp_ids", []); diff --git a/test/lean/enum.expected.lean b/test/lean/enum.expected.lean index 637d9c28d..a9a1a2a73 100644 --- a/test/lean/enum.expected.lean +++ b/test/lean/enum.expected.lean @@ -4,6 +4,7 @@ open Sail inductive E where | A | B | C deriving Inhabited +open E def undefined_E : SailM E := do sorry diff --git a/test/lean/match.expected.lean b/test/lean/match.expected.lean new file mode 100644 index 000000000..ac5ba5ce4 --- /dev/null +++ b/test/lean/match.expected.lean @@ -0,0 +1,40 @@ +import Out.Sail.Sail + +open Sail + +inductive E where | A | B | C + deriving Inhabited +open E + +def undefined_E : SailM E := do + sorry + +def match_enum (x : E) : (BitVec 1) := + match x with + | A => 1#1 + | B => 1#1 + | C => 0#1 + + +def match_option (x : (Option (BitVec 1))) : (BitVec 1) := + match x with + | Some x => x + | None () => 0#1 + + +/-- Type quantifiers: y : Int, x : Int -/ +def match_pair_pat (x : Int) (y : Int) : Int := + match (x, y) with + | (a, b) => (HAdd.hAdd a b) + + +/-- Type quantifiers: arg1 : Int, arg0 : Int -/ +def match_pair (arg0 : Int) (arg1 : Int) : Int := + let x := (arg0, arg1) + match x with + | (a, b) => (HAdd.hAdd a b) + + +def initialize_registers : Unit := + () + diff --git a/test/lean/match.sail b/test/lean/match.sail new file mode 100644 index 000000000..45220f4d3 --- /dev/null +++ b/test/lean/match.sail @@ -0,0 +1,33 @@ +default Order dec + +$include + +$[no_enum_number_conversions] +enum E = A | B | C + +function match_enum(x : E) -> bit = { + match x { + A => bitone, + B => bitone, + C => bitzero, + } +} + +function match_option(x : option(bit)) -> bit = { + match x { + Some(x) => x, + None() => bitzero, + } +} + +function match_pair_pat((x, y) : (int, int)) -> int = { + match (x, y) { + (a, b) => a + b, + } +} + +function match_pair(x : (int, int)) -> int = { + match x { + (a, b) => a + b, + } +}