Skip to content

Commit

Permalink
support for implicit arguments in function application
Browse files Browse the repository at this point in the history
  • Loading branch information
javra committed Feb 4, 2025
1 parent a5dbda8 commit 3cbf2fd
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 16 deletions.
48 changes: 32 additions & 16 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ open Rewriter
open PPrint
open Pretty_print_common

type global_context = { effect_info : Effects.side_effect_info }
type global_context = { effect_info : Effects.side_effect_info; implicitness : bool list Bindings.t }

type context = {
global : global_context;
Expand Down Expand Up @@ -422,6 +422,11 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
else doc_exp false ctx (E_aux (E_id f, (l, annot)))
in
let d_args = List.map d_of_arg args in
let d_args =
match Bindings.find_opt f ctx.global.implicitness with
| Some is -> List.map snd (List.filter (fun x -> not (fst x)) (List.combine is d_args))
| None -> d_args
in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2
(wrap_with_left_arrow ((not as_monadic) && fn_monadic)
Expand Down Expand Up @@ -482,15 +487,15 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp false ctx e

let doc_binder ctx i t =
let paranthesizer =
let parenthesizer, is_implicit =
match t with
| Typ_aux (Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]), _) ->
implicit_parens
| _ -> parens
(implicit_parens, true)
| _ -> (parens, false)
in
(* Overwrite the id if it's captured *)
let ctx = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctx i ki | _ -> ctx in
(ctx, separate space [doc_id_ctor i; colon; doc_typ ctx t] |> paranthesizer)
(ctx, separate space [doc_id_ctor i; colon; doc_typ ctx t] |> parenthesizer, is_implicit)

let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
Expand All @@ -512,13 +517,13 @@ let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
)
in
let ctx = initial_context env global in
let ctx, binders =
let ctx, binders, implicitness =
List.fold_left
(fun (ctx, bs) (i, t) ->
let ctx, d = doc_binder ctx i t in
(ctx, bs @ [d])
(fun (ctx, bs, is) (i, t) ->
let ctx, d, i = doc_binder ctx i t in
(ctx, bs @ [d], is @ [i])
)
(ctx, []) binders
(ctx, [], []) binders
in
let typ_quants = doc_typ_quant ctx tq in
let typ_quant_comment =
Expand All @@ -534,7 +539,12 @@ let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let decl_val = [doc_ret_typ; coloneq] in
(* Add do block for stateful functions *)
let decl_val = if is_monadic then decl_val @ [string "do"] else decl_val in
(typ_quant_comment, separate space ([string "def"; doc_id_ctor id] @ binders @ [colon] @ decl_val), env, fixup_binders)
( typ_quant_comment,
separate space ([string "def"; doc_id_ctor id] @ binders @ [colon] @ decl_val),
env,
fixup_binders,
implicitness
)

let doc_funcl_body fixup_binders global (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
Expand All @@ -547,13 +557,18 @@ let doc_funcl_body fixup_binders global (FCL_aux (FCL_funcl (id, pexp), annot))
doc_exp is_monadic (initial_context env global) exp

let doc_funcl ctx funcl =
let comment, signature, env, fixup_binders = doc_funcl_init ctx.global funcl in
comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders ctx.global funcl)
let comment, signature, env, fixup_binders, implicitness = doc_funcl_init ctx.global funcl in
(comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders ctx.global funcl), implicitness)

let doc_fundef ctx (FD_aux (FD_function (r, typa, fcls), fannot)) =
match fcls with
| [] -> failwith "FD_function with empty function list"
| [funcl] -> doc_funcl ctx funcl
| [(FCL_aux (FCL_funcl (id, pexp), annot) as funcl)] ->
let pp_funcl, implicitness = doc_funcl ctx funcl in
let ctx =
{ ctx with global = { ctx.global with implicitness = Bindings.add id implicitness ctx.global.implicitness } }
in
(ctx, pp_funcl)
| _ -> failwith "FD_function with more than one clause"

let doc_type_union ctx (Tu_aux (Tu_ty_id (ty, i), _)) =
Expand Down Expand Up @@ -626,7 +641,8 @@ let rec doc_defs_rec ctx defs types docdefs =
match defs with
| [] -> (types, docdefs)
| DEF_aux (DEF_fundef fdef, _) :: defs' ->
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
let ctx', pp_fun = doc_fundef ctx fdef in
doc_defs_rec ctx' defs' types (docdefs ^^ group pp_fun ^/^ hardline)
| DEF_aux (DEF_type tdef, _) :: defs' ->
doc_defs_rec ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) docdefs
| DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), _)), _) :: defs' ->
Expand Down Expand Up @@ -689,7 +705,7 @@ let doc_monad_abbrev (has_registers : bool) =
let pp_ast_lean (env : Type_check.env) effect_info ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let regs = State.find_registers defs in
let global = { effect_info } in
let global = { effect_info; implicitness = Bindings.empty } in
let has_registers = List.length regs > 0 in
let register_refs = if has_registers then doc_reg_info env global regs else empty in
let monad = doc_monad_abbrev has_registers in
Expand Down
16 changes: 16 additions & 0 deletions test/lean/implicit.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import Out.Sail.Sail

open Sail

abbrev SailM := StateM Unit

/-- Type quantifiers: k_n : Int, m : Int, m ≥ k_n -/
def EXTZ {m : _} (v : (BitVec k_n)) : (BitVec m) :=
(Sail.BitVec.zeroExtend v m)

def foo (x : (BitVec 8)) : (BitVec 16) :=
(EXTZ x)

def initialize_registers (lit : Unit) : Unit :=
()

11 changes: 11 additions & 0 deletions test/lean/implicit.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
default Order dec

$include <prelude.sail>

val EXTZ : forall 'n 'm, 'm >= 'n. (implicit('m), bits('n)) -> bits('m)
function EXTZ(m, v) = sail_zero_extend(v, m)

val foo : bits(8) -> bits(16)
function foo x = {
EXTZ(x)
}

0 comments on commit 3cbf2fd

Please sign in to comment.