Skip to content

Commit

Permalink
Optimise monadic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ncough authored and katrinafyi committed Jan 16, 2024
1 parent 5c2a3e7 commit aceb49c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 66 deletions.
71 changes: 40 additions & 31 deletions libASL/dis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,9 @@ module DisEnv = struct

open Let

let getVar (loc: l) (x: ident): (ty * sym) rws =
let+ (_,v) = gets (LocalEnv.resolveGetVar loc x) in
v
let getVar (loc: l) (x: ident): (ty * sym) rws = fun env s ->
let (_,v) = LocalEnv.resolveGetVar loc x s in
(v,s,empty)

let uninit (t: ty) (env: Eval.Env.t): value =
try
Expand All @@ -351,27 +351,26 @@ module DisEnv = struct
| true -> v1) in
out) l r

let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws =
let* env = read in
let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws = fun env s ->
assert (l.returnSymbols = r.returnSymbols);
assert (l.indent = r.indent);
assert (l.trace = r.trace);
let locals' = List.map2 (merge_bindings env) l.locals r.locals in
put {
let s : LocalEnv.t = {
locals = locals';
returnSymbols = l.returnSymbols;
numSymbols = max l.numSymbols r.numSymbols;
indent = l.indent;
trace = l.trace;
}

} in
((),s,empty)

let getFun (loc: l) (x: ident): Eval.fun_sig option rws =
reads (fun env -> Eval.Env.getFunOpt loc env x)

let nextVarName (prefix: string): ident rws =
let+ num = stateful LocalEnv.incNumSymbols in
Ident (prefix ^ string_of_int num)
let nextVarName (prefix: string): ident rws = fun env s ->
let num, s = LocalEnv.incNumSymbols s in
(Ident (prefix ^ string_of_int num),s,empty)

let indent: string rws =
let+ i = gets (fun l -> l.indent) in
Expand Down Expand Up @@ -436,13 +435,24 @@ end

type 'a rws = 'a DisEnv.rws

let (let@) = DisEnv.Let.(let*)
let (and@) = DisEnv.Let.(and*)
let (let+) = DisEnv.Let.(let+)
let (and+) = DisEnv.Let.(and+)
let (let@) x f = fun env s ->
let (r,s,w) = x env s in
let (r',s,w') = (f r) env s in
(r',s,append w w')

let (let+) x f = fun env s ->
let (r,s,w) = x env s in
(f r,s,w)

let (>>) = DisEnv.(>>)
let (>>=) = DisEnv.(>>=)
let (>>) x f = fun env s ->
let (_,s,w) = x env s in
let (r',s,w') = f env s in
(r',s,append w w')

let (>>=) x f = fun env s ->
let (r,s,w) = x env s in
let (r',s,w') = (f r) env s in
(r',s,append w w')

(** Convert value to a simple expression containing that value, so we can
print it or use it symbolically *)
Expand All @@ -462,7 +472,6 @@ let is_expr (v: sym): bool =
| Exp _ -> true

let declare_var (loc: l) (t: ty) (i: ident): var rws =
let@ env = DisEnv.read in
let@ uninit = DisEnv.mkUninit t in
let@ var = DisEnv.stateful
(LocalEnv.addLocalVar loc i (Val uninit) t) in
Expand Down Expand Up @@ -694,8 +703,8 @@ and dis_pattern (loc: l) (v: sym) (x: AST.pattern): sym rws =
let+ v' = dis_expr loc e in
sym_eq loc v v'
| Pat_Range(lo, hi) ->
let+ lo' = dis_expr loc lo
and+ hi' = dis_expr loc hi in
let@ lo' = dis_expr loc lo in
let+ hi' = dis_expr loc hi in
sym_and_bool loc (sym_le_int loc lo' v) (sym_le_int loc v hi')
)

Expand All @@ -706,13 +715,13 @@ and dis_slice (loc: l) (x: slice): (sym * sym) rws =
let+ i' = dis_expr loc i in
(i', Val (VInt Z.one))
| Slice_HiLo(hi, lo) ->
let+ hi' = dis_expr loc hi
and+ lo' = dis_expr loc lo in
let@ hi' = dis_expr loc hi in
let+ lo' = dis_expr loc lo in
let wd' = sym_add_int loc (sym_sub_int loc hi' lo') (Val (VInt Z.one)) in
(lo', wd')
| Slice_LoWd(lo, wd) ->
let+ lo' = dis_expr loc lo
and+ wd' = dis_expr loc wd in
let@ lo' = dis_expr loc lo in
let+ wd' = dis_expr loc wd in
(lo', wd')
)

Expand Down Expand Up @@ -796,10 +805,10 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws =
let+ vs = DisEnv.traverse (fun f -> dis_load loc (Expr_Field(e,f))) fs in
sym_concat loc vs
| Expr_Slices(e, ss) ->
let@ e' = dis_expr loc e
and@ ss' = DisEnv.traverse (dis_slice loc) ss in
let@ e' = dis_expr loc e in
let+ ss' = DisEnv.traverse (dis_slice loc) ss in
let vs = List.map (fun (i,w) -> sym_extract_bits loc e' i w) ss' in
DisEnv.pure (sym_concat loc vs)
sym_concat loc vs
| Expr_In(e, p) ->
let@ e' = dis_expr loc e in
let@ p' = dis_pattern loc e' p in
Expand Down Expand Up @@ -907,7 +916,7 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws
(match rty with
| Some rty ->
let@ () = DisEnv.modify LocalEnv.addLevel in
let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e ->
let@ () = DisEnv.traverse2_ (fun arg e ->
declare_const loc type_integer arg e
) targs tes in

Expand All @@ -930,18 +939,18 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws
assert (List.length targs == List.length tes);

(* Assign targs := tes *)
let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e ->
let@ () = DisEnv.traverse2_ (fun arg e ->
declare_const loc type_integer arg e
) targs tes in

assert (List.length atys == List.length args);
assert (List.length atys == List.length es);

(* Assign args := es *)
let@ () = DisEnv.sequence_ (Utils.map3 (fun (ty, _) arg e ->
let@ () = DisEnv.traverse3_ (fun (ty, _) arg e ->
let@ ty' = dis_type loc ty in
declare_const loc ty' arg e
) atys args es) in
) atys args es in

(* Create return variable (if necessary).
This is in the inner scope to allow for type parameters. *)
Expand Down
35 changes: 0 additions & 35 deletions libASL/monad.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,41 +42,6 @@ module Make (M : S) = struct
let (and*) = (and+)
end

open Let

(* higher-order functions and transformations *)

(** Performs a list of computations in sequence, resulting in a list
of their results. *)
let rec sequence (xs: 'a m list): 'a list m =
match xs with
| (x::xs) ->
let+ x = x
and+ xs = sequence xs in
(x :: xs)
| [] -> pure []

(** Performs a list of computations in sequence and discard their results
(but retains their monad effects). *)
let sequence_ (xs : 'a m list): unit m =
let+ _ = sequence xs in ()

(** Uses the given function to create a list of computations which are
then run sequentially. Results in a list of their results. *)
let traverse (f: 'a -> 'b m) (x: 'a list): 'b list m =
sequence (List.map f x)

let traverse2 (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): 'c list m =
sequence (List.map2 f x y)

(** Uses the given function to create a list of computations which are
then run sequentually. Discards their results. *)
let traverse_ (f: 'a -> 'b m) (x: 'a list): unit m =
let+ _ = sequence (List.map f x) in ()

let traverse2_ (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): unit m =
let+ _ = sequence (List.map2 f x y) in ()

(** A nil computation. Does nothing and returns nothing of interest. *)
let unit: unit m = pure ()

Expand Down
40 changes: 40 additions & 0 deletions libASL/rws.ml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,46 @@ module RWSBase (T : S) = struct
let bt = Printexc.get_raw_backtrace () in
(Error (e, bt), s, mempty)

let rec traverse (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) =
match x with
| [] -> ([],s,mempty)
| x::xs ->
let (i,s,w) = (f x) r s in
let (is,s,w') = traverse f xs r s in
(i::is,s,mappend w w')

let rec traverse_r (w: w) (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) =
match x with
| [] -> ((),s,w)
| x::xs ->
let (_,s',w') = (f x) r s in
traverse_r (mappend w w') f xs r s'

let traverse_ (f: 'a -> 'b rws) (x: 'a list): unit rws =
traverse_r mempty f x

let rec traverse2_r (w: w) (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list) (r: r) (s: s) =
match x, y with
| [], [] -> ((),s,w)
| x::xs, y::ys ->
let (_,s',w') = (f x y) r s in
traverse2_r (mappend w w') f xs ys r s'
| _ -> invalid_arg "traverse2_"

let traverse2_ (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list): unit rws =
traverse2_r mempty f x y

let rec traverse3_r (w: w) (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list) (r: r) (s: s) =
match x, y, z with
| [], [], [] -> ((),s,w)
| x::xs, y::ys, z::zs ->
let (_,s',w') = (f x y z) r s in
traverse3_r (mappend w w') f xs ys zs r s'
| _ -> invalid_arg "traverse3_"

let traverse3_ (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list): unit rws =
traverse3_r mempty f x y z

end

(** Constructs a RWS monad using the given signature. *)
Expand Down

0 comments on commit aceb49c

Please sign in to comment.