Skip to content

Commit

Permalink
Add vectorization pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ncough committed Jul 9, 2024
1 parent 0bd3d12 commit 17b0aa9
Show file tree
Hide file tree
Showing 9 changed files with 1,396 additions and 210 deletions.
56 changes: 42 additions & 14 deletions libASL/dis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ module StringCmp = struct
end
module StringMap = Map.Make(StringCmp)

let unroll_bound = Z.of_int 2

let debug_level_none = -1
let debug_level = ref debug_level_none
Expand Down Expand Up @@ -116,6 +115,8 @@ let no_inline_pure = [
"ASR",0;
"SignExtend",0;
"ZeroExtend",0;
"Elem.set",0;
"Elem.read",0;
]

(** A variable's stack level and original identifier name.
Expand Down Expand Up @@ -327,9 +328,14 @@ let rec flatten x acc =
flatten x acc
| Node i -> i@acc

type config = {
eval_env: Eval.Env.t;
unroll_bound: Z.t;
}

module DisEnv = struct
include Rws.Make(struct
type r = Eval.Env.t
type r = config
type w = tree
type s = LocalEnv.t
let mempty = empty
Expand All @@ -342,9 +348,9 @@ module DisEnv = struct
let (_,v) = LocalEnv.resolveGetVar loc x s in
(v,s,empty)

let uninit (t: ty) (env: Eval.Env.t): value =
let uninit (t: ty) (config): value =
try
Eval.mk_uninitialized Unknown env t
Eval.mk_uninitialized Unknown config.eval_env t
with
e -> unsupported Unknown @@
"mkUninit: failed to evaluate type " ^ pp_type t ^ " due to " ^
Expand Down Expand Up @@ -393,7 +399,6 @@ module DisEnv = struct
| _, Val (VUninitialized _) -> v2
| _, Val (VArray (ar,v)) when Primops.ImmutableArray.is_empty ar -> v2
| _ ->
Printf.printf "Merge %s %s %s\n" k (pp_sym v1) (pp_sym v2);
changed := true;
Val (uninit t1 env)) in
out) l r) in (out, !changed)
Expand All @@ -416,7 +421,7 @@ module DisEnv = struct
(changed,s,empty)

let getFun (loc: l) (x: ident): Eval.fun_sig option rws =
reads (fun env -> Eval.Env.getFunOpt loc env x)
reads (fun config -> Eval.Env.getFunOpt loc config.eval_env x)

let nextVarName (prefix: string): ident rws = fun env s ->
let num, s = LocalEnv.incNumSymbols s in
Expand Down Expand Up @@ -860,6 +865,10 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws =
let+ vs = DisEnv.traverse (fun f -> dis_load_with_type loc (Expr_Field(e,f))) fs in
let vs' = List.map (fun (t,x) -> (width_of_type loc t, x)) vs in
sym_concat loc vs'
| Expr_Slices(e, [s]) ->
let@ e' = dis_expr loc e in
let+ (i,w) = dis_slice loc s in
sym_extract_bits loc e' i w
| Expr_Slices(e, ss) ->
let@ e' = dis_expr loc e in
let+ ss' = DisEnv.traverse (dis_slice loc) ss in
Expand Down Expand Up @@ -916,7 +925,7 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws =
let+ t' = dis_type loc t in
Exp (Expr_Unknown(t'))
| Expr_ImpDef(t, Some(s)) ->
DisEnv.reads (fun env -> Val (Eval.Env.getImpdef loc env s))
DisEnv.reads (fun config -> Val (Eval.Env.getImpdef loc config.eval_env s))
| Expr_ImpDef(t, None) ->
raise (EvalError (loc, "unnamed IMPLEMENTATION_DEFINED behavior"))
| Expr_Array(a,i) -> dis_load loc x
Expand Down Expand Up @@ -1340,11 +1349,14 @@ and dis_stmt' (x: AST.stmt): unit rws =
| Stmt_For(var, start, dir, stop, body, loc) ->
let@ start' = dis_expr loc start in
let@ stop' = dis_expr loc stop in
let@ unroll_bound = DisEnv.reads (fun config -> config.unroll_bound) in

let unrolling =
(match start', stop', dir with
| Val (VInt startval), Val (VInt stopval), Direction_Up -> Z.leq (Z.sub stopval startval) unroll_bound
| Val (VInt startval), Val (VInt stopval), Direction_Down -> Z.leq (Z.sub startval stopval) unroll_bound
| Val (VInt startval), Val (VInt stopval), Direction_Up ->
Z.lt (Z.sub stopval startval) unroll_bound
| Val (VInt startval), Val (VInt stopval), Direction_Down ->
Z.lt (Z.sub startval stopval) unroll_bound
| _ -> false) in

(match unrolling, start', stop' with
Expand Down Expand Up @@ -1376,11 +1388,13 @@ and dis_stmt' (x: AST.stmt): unit rws =
(* Join state after, determining whether anything changed *)
let@ changed = DisEnv.join_locals_fp env pre_env in
(* If a change occurred, run again *)

if changed then loop v else
let start' = sym_expr start' in
let stop' = sym_expr stop' in
DisEnv.write [Stmt_For(var_ident v,start',dir,stop',flatten stmts [],loc)]
in

(* Add the loop variable to state, make it unknown *)
let@ uninit = DisEnv.mkUninit type_integer in
let@ v = DisEnv.stateful (LocalEnv.addLocalVar loc var (Val uninit) type_integer) in
Expand All @@ -1407,7 +1421,7 @@ let dis_encoding (x: encoding) (op: Primops.bigint): bool rws =
(* todo: consider checking iset *)
(* Printf.printf "Checking opcode match %s == %s\n" (Utils.to_string (PP.pp_opcode_value opcode)) (pp_value op); *)
match Eval.eval_opcode_guard loc opcode op with
| Some op ->
| Some op ->
if !Eval.trace_instruction then Printf.printf "TRACE: instruction %s\n" (pprint_ident nm);

let@ () = DisEnv.traverse_ (function (IField_Field (f, lo, wd)) ->
Expand Down Expand Up @@ -1479,7 +1493,7 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op:
| DecoderBody_UNALLOC loc -> raise (Throw (loc, Exc_Undefined))
| DecoderBody_NOP loc -> DisEnv.pure true
| DecoderBody_Encoding (inst, l) ->
let@ (enc, opost, cond, exec) = DisEnv.reads (fun env -> Eval.Env.getInstruction loc env inst) in
let@ (enc, opost, cond, exec) = DisEnv.reads (fun config -> Eval.Env.getInstruction loc config.eval_env inst) in
let@ enc_match = dis_encoding enc op in
if enc_match then begin
(* todo: should evaluate ConditionHolds to decide whether to execute body *)
Expand Down Expand Up @@ -1539,15 +1553,19 @@ let enum_types env i =
| Some l -> Some (Z.log2up (Z.of_int (List.length l)))
| _ -> None

let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list =
(* Actually perform dis *)
let dis_core (env: Eval.Env.t) (unroll_bound) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list =
let DecoderCase_Case (_,_,loc) = decode in
let ((),lenv',stmts) = (dis_decode_case loc decode op) env lenv in
let config = { eval_env = env ; unroll_bound } in

let ((),lenv',stmts) = (dis_decode_case loc decode op) config lenv in
let varentries = List.(concat @@ map (fun vars -> StringMap.(bindings (map fst vars))) lenv.locals) in
let bindings = Asl_utils.Bindings.of_seq @@ List.to_seq @@ List.map (fun (x,y) -> (Ident x,y)) varentries in
(* List.iter (fun (v,t) -> Printf.printf ("%s:%s\n") v (pp_type t)) varentries; *)
let stmts = flatten stmts [] in
let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in
let stmts' = Transforms.RedundantSlice.do_transform Bindings.empty stmts' in
let stmts' = Transforms.FixRedefinitions.run (globals : IdentSet.t) stmts' in
let stmts' = Transforms.StatefulIntToBits.run (enum_types env) stmts' in
let stmts' = Transforms.IntToBits.ints_to_bits stmts' in
let stmts' = Transforms.CommonSubExprElim.do_transform stmts' in
Expand All @@ -1556,7 +1574,6 @@ let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_cas
let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts' in
let stmts' = Transforms.CaseSimp.do_transform stmts' in
let stmts' = Transforms.RemoveRegisters.run stmts' in
let stmts' = Transforms.FixRedefinitions.run (globals : IdentSet.t) stmts' in

if !debug_level >= 2 then begin
let stmts' = Asl_visitor.visit_stmts (new Asl_utils.resugarClass (!TC.binop_table)) stmts' in
Expand All @@ -1566,6 +1583,17 @@ let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_cas
end;
stmts'

(* Wrapper around the core to attempt loop vectorization, reverting back if this fails.
This is a complete hack, but it is nicer to make the loop unrolling decision during
partial evaluation, rather than having to unroll after we know vectorization failed.
*)
let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list =
let unroll_bound = Z.of_int 1 in
let stmts' = dis_core env unroll_bound (lenv,globals) decode op in
let (res,stmts') = Transforms.LoopClassify.run stmts' env in
if res then stmts' else
dis_core env (Z.of_int 1000) (lenv,globals) decode op

let build_env (env: Eval.Env.t): env =
let env = Eval.Env.freeze env in
let lenv = LocalEnv.init env in
Expand Down
2 changes: 1 addition & 1 deletion libASL/dune
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
lexer lexersupport loadASL monad primops rws symbolic tcheck testing transforms value
symbolic_lifter decoder_program call_graph req_analysis
offline_transform ocaml_backend dis_tc offline_opt
arm_env
arm_env pretransforms
)
(preprocessor_deps (alias ../asl_files))
(preprocess (pps ppx_blob))
Expand Down
10 changes: 5 additions & 5 deletions libASL/eval.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ and eval_encoding (env: Env.t) (x: encoding) (op: Primops.bigint): bool =
(* todo: consider checking iset *)
(* Printf.printf "Checking opcode match %s == %s\n" (Utils.to_string (PP.pp_opcode_value opcode)) (pp_value op); *)
match eval_opcode_guard loc opcode op with
| Some op ->
| Some op ->
if !trace_instruction then Printf.printf "TRACE: instruction %s\n" (pprint_ident nm);
List.iter (function (IField_Field (f, lo, wd)) ->
let v = extract_bits' loc op lo wd in
Expand Down Expand Up @@ -1211,7 +1211,7 @@ let build_evaluation_environment (ds: AST.declaration list): Env.t = begin
if false then Printf.printf "Building environment from %d declarations\n" (List.length ds);

(* perform reference parameter transformation. *)
let ds = Transforms.RefParams.ref_param_conversion ds in
let ds = Pretransforms.RefParams.ref_param_conversion ds in

let env = Env.empty in
(* todo?: first pull out the constants/configs and evaluate all of them
Expand Down Expand Up @@ -1346,14 +1346,14 @@ let evaluate_prj_minimal (tcenv: Tcheck.Env.t) (env: Env.t) (source: LoadASL.sou

(** Constructs an evaluation environment with the given prelude file and .asl/.prj files.
.prj files given here are required to be minimal. *)
let evaluation_environment (prelude: LoadASL.source) (files: LoadASL.source list) (verbose: bool) =
let evaluation_environment (prelude: LoadASL.source) (files: LoadASL.source list) (verbose: bool) =
let t = LoadASL.read_file (prelude) true verbose in
let ts = List.map (fun file ->
let filename = LoadASL.name_of_source file in
if Utils.endswith filename ".spec" then begin
LoadASL.read_spec file verbose
LoadASL.read_spec file verbose
end else if Utils.endswith filename ".asl" then begin
LoadASL.read_file file false verbose
LoadASL.read_file file false verbose
end else if Utils.endswith filename ".prj" then begin
[] (* ignore project files here and process later *)
end else begin
Expand Down
174 changes: 174 additions & 0 deletions libASL/pretransforms.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
open Asl_utils

open AST
open Visitor

(** Transforms setters using formal reference (in/out) parameters
into functions returning modified versions of the reference parameters.
*)
module RefParams = struct

(** Filters the given list of sformal, returning a list of
(argument index, type, argument name) with only the ref params. *)
let get_ref_params (xs: sformal list): (int * ty * ident) list =
let xs = List.mapi (fun i x -> (i,x)) xs in
List.filter_map
(fun (n,f) ->
match f with
| Formal_InOut (t,i) -> Some (n,t,i)
| _ -> None)
xs

(** Replaces all procedure returns in the given statement list
with the given statement. *)
let replace_returns ss s =
let visit = object
inherit Asl_visitor.nopAslVisitor
method! vstmt =
function
| Stmt_ProcReturn _ -> ChangeTo s
| Stmt_FunReturn _ -> failwith "unexpected function return in ref param conversion."
| _ -> DoChildren
end
in
Asl_visitor.visit_stmts visit ss

(** Replaces setter declarations which use formal in-out parameters with
functions which return their modified parameters.
For example,
Elem[bits(N) &vector, integer e] = bits(size) value
...
return;
is transformed to
(bits(N)) Elem.read(bits(N) vector, integer e, bits(size) value)
...
return (vector);
*)
class visit_decls = object
inherit Asl_visitor.nopAslVisitor

(* mapping of function identifiers to their (new) signature along with
the indices of their. *)
val mutable ref_params : (Tcheck.funtype * int list) Bindings.t = Bindings.empty

method ref_params = ref_params

method! vdecl (d: declaration): declaration Visitor.visitAction =
match d with
| Decl_ArraySetterDefn (nm, args, vty, vnm, body, loc)->
(match get_ref_params args with
| [] -> DoChildren
| refs ->
(* indices, types, and identifiers for the ref params. *)
let ns = List.map (fun (n,_,_) -> n) refs in
let ts = List.map (fun (_,t,_) -> t) refs in
let is = List.map (fun (_,_,i) -> i) refs in

(* append setter value argument to formal argument list. *)
let args' = List.map Tcheck.formal_of_sformal args @ [vty, vnm] in


(* construct return expression to return modified ref vars. *)
let vars = List.map (fun x -> Expr_Var x) is in
let ret = Stmt_FunReturn ((match vars with [x] -> x | _ -> Expr_Tuple vars), loc) in
let body' = replace_returns body [ret] in

let rty = match ts with [t] -> t | _ -> Type_Tuple ts in
let funty = (nm, false, [], [], List.map Asl_visitor.arg_of_sformal args @ [(vty, vnm)], rty) in
ref_params <- Bindings.add nm (funty,ns) ref_params;
ChangeTo (Decl_FunDefn (rty, nm, args', body', loc))
)
| _ -> DoChildren
end

(** Replaces writes to the setters modified above to assign
the return value back to the original variables.
For example,
Elem[vector, 2] = '1001';
is transformed to
vector = Elem.read(vector, 2, '1001');
*)
class visit_writes (ref_params: (Tcheck.funtype * int list) Bindings.t) = object
inherit Asl_visitor.nopAslVisitor

val mutable n = 0;

method! vstmt (s: stmt) =
match s with
| Stmt_Assign (LExpr_Write (setter, targs, args), r, loc) ->
(match Bindings.find_opt setter ref_params with
| None -> DoChildren
| Some (_,ns) ->
let refs = List.map (List.nth args) ns in
(* Printf.printf "ref param: %s\n" (pp_expr a); *)

let les = List.map Symbolic.expr_to_lexpr refs in
let call = Expr_TApply (setter, targs, args @ [r]) in
ChangeTo [Stmt_Assign ((match les with [x] -> x | _ -> LExpr_Tuple les), call, loc)]
)
(* case where a write expression is used within a tuple destructuring. *)
| Stmt_Assign (LExpr_Tuple(LExpr_Write (setter, tes, es) :: rest), r, loc) ->
(match Bindings.find_opt setter ref_params with
| None -> DoChildren
| Some ((nm, _, _, _, args, _),ns) ->

n <- n + 1;
(* create new variable to store value to be passed to setter. *)
let rvar = Ident ("Write_" ^ pprint_ident (stripTag setter) ^ string_of_int n) in
(* arguments to setter function appended with r-value. *)
let es' = es @ [Expr_Var rvar] in

(* infer value argument type of setter by substituting arguments into
the last type argument. *)
let subs = List.combine (List.map snd args) es' in
let sub_bindings = Bindings.of_seq (List.to_seq subs) in
let (vty,_) = List.hd (List.rev args) in
let vty = subst_type sub_bindings vty in

(* emit: vty rvar declaration *)
let decl_var = Stmt_VarDeclsNoInit (vty, [rvar], loc) in
(* emit: (rvar, ...) = r *)
let assign_tuple = Stmt_Assign (LExpr_Tuple (LExpr_Var rvar :: rest), r, loc) in

let refs = List.map (List.nth es') ns in
let les = List.map Symbolic.expr_to_lexpr refs in
let write_call = Expr_TApply (setter, tes, es') in
(* emit: (refparams) = __write(es, rvar) *)
let assign_write = Stmt_Assign ((match les with [x] -> x | _ -> LExpr_Tuple les), write_call, loc) in

let x =
[decl_var; assign_tuple; assign_write]
in
ChangeTo x
)
| _ -> DoChildren

method! vlexpr le =
match le with
| LExpr_Write (nm, _, _) when Bindings.mem nm ref_params ->
failwith @@ "unexpected write using parameters by reference: " ^ pp_lexpr le
| _ -> DoChildren
end

let ref_param_conversion (ds: declaration list) =
let v1 = new visit_decls in
let ds = List.map (Asl_visitor.visit_decl (v1 :> Asl_visitor.aslVisitor)) ds in
let v2 = new visit_writes (v1#ref_params) in
let ds = List.map (Asl_visitor.visit_decl v2) ds in
ds
(* Tcheck.GlobalEnv.clear Tcheck.env0;
Tcheck.tc_declarations false ds *)
end


Loading

0 comments on commit 17b0aa9

Please sign in to comment.