diff --git a/libASL/dis.ml b/libASL/dis.ml index 13645d85..b5417a5a 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -23,7 +23,6 @@ module StringCmp = struct end module StringMap = Map.Make(StringCmp) -let unroll_bound = Z.of_int 2 let debug_level_none = -1 let debug_level = ref debug_level_none @@ -116,6 +115,8 @@ let no_inline_pure = [ "ASR",0; "SignExtend",0; "ZeroExtend",0; + "Elem.set",0; + "Elem.read",0; ] (** A variable's stack level and original identifier name. @@ -327,9 +328,14 @@ let rec flatten x acc = flatten x acc | Node i -> i@acc +type config = { + eval_env: Eval.Env.t; + unroll_bound: Z.t; +} + module DisEnv = struct include Rws.Make(struct - type r = Eval.Env.t + type r = config type w = tree type s = LocalEnv.t let mempty = empty @@ -342,9 +348,9 @@ module DisEnv = struct let (_,v) = LocalEnv.resolveGetVar loc x s in (v,s,empty) - let uninit (t: ty) (env: Eval.Env.t): value = + let uninit (t: ty) (config): value = try - Eval.mk_uninitialized Unknown env t + Eval.mk_uninitialized Unknown config.eval_env t with e -> unsupported Unknown @@ "mkUninit: failed to evaluate type " ^ pp_type t ^ " due to " ^ @@ -393,7 +399,6 @@ module DisEnv = struct | _, Val (VUninitialized _) -> v2 | _, Val (VArray (ar,v)) when Primops.ImmutableArray.is_empty ar -> v2 | _ -> - Printf.printf "Merge %s %s %s\n" k (pp_sym v1) (pp_sym v2); changed := true; Val (uninit t1 env)) in out) l r) in (out, !changed) @@ -416,7 +421,7 @@ module DisEnv = struct (changed,s,empty) let getFun (loc: l) (x: ident): Eval.fun_sig option rws = - reads (fun env -> Eval.Env.getFunOpt loc env x) + reads (fun config -> Eval.Env.getFunOpt loc config.eval_env x) let nextVarName (prefix: string): ident rws = fun env s -> let num, s = LocalEnv.incNumSymbols s in @@ -860,6 +865,10 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws = let+ vs = DisEnv.traverse (fun f -> dis_load_with_type loc (Expr_Field(e,f))) fs in let vs' = List.map (fun (t,x) -> (width_of_type loc t, x)) vs in sym_concat loc vs' + | Expr_Slices(e, [s]) -> + let@ e' = dis_expr loc e in + let+ (i,w) = dis_slice loc s in + sym_extract_bits loc e' i w | Expr_Slices(e, ss) -> let@ e' = dis_expr loc e in let+ ss' = DisEnv.traverse (dis_slice loc) ss in @@ -916,7 +925,7 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws = let+ t' = dis_type loc t in Exp (Expr_Unknown(t')) | Expr_ImpDef(t, Some(s)) -> - DisEnv.reads (fun env -> Val (Eval.Env.getImpdef loc env s)) + DisEnv.reads (fun config -> Val (Eval.Env.getImpdef loc config.eval_env s)) | Expr_ImpDef(t, None) -> raise (EvalError (loc, "unnamed IMPLEMENTATION_DEFINED behavior")) | Expr_Array(a,i) -> dis_load loc x @@ -1340,11 +1349,14 @@ and dis_stmt' (x: AST.stmt): unit rws = | Stmt_For(var, start, dir, stop, body, loc) -> let@ start' = dis_expr loc start in let@ stop' = dis_expr loc stop in + let@ unroll_bound = DisEnv.reads (fun config -> config.unroll_bound) in let unrolling = (match start', stop', dir with - | Val (VInt startval), Val (VInt stopval), Direction_Up -> Z.leq (Z.sub stopval startval) unroll_bound - | Val (VInt startval), Val (VInt stopval), Direction_Down -> Z.leq (Z.sub startval stopval) unroll_bound + | Val (VInt startval), Val (VInt stopval), Direction_Up -> + Z.lt (Z.sub stopval startval) unroll_bound + | Val (VInt startval), Val (VInt stopval), Direction_Down -> + Z.lt (Z.sub startval stopval) unroll_bound | _ -> false) in (match unrolling, start', stop' with @@ -1376,11 +1388,13 @@ and dis_stmt' (x: AST.stmt): unit rws = (* Join state after, determining whether anything changed *) let@ changed = DisEnv.join_locals_fp env pre_env in (* If a change occurred, run again *) + if changed then loop v else let start' = sym_expr start' in let stop' = sym_expr stop' in DisEnv.write [Stmt_For(var_ident v,start',dir,stop',flatten stmts [],loc)] in + (* Add the loop variable to state, make it unknown *) let@ uninit = DisEnv.mkUninit type_integer in let@ v = DisEnv.stateful (LocalEnv.addLocalVar loc var (Val uninit) type_integer) in @@ -1407,7 +1421,7 @@ let dis_encoding (x: encoding) (op: Primops.bigint): bool rws = (* todo: consider checking iset *) (* Printf.printf "Checking opcode match %s == %s\n" (Utils.to_string (PP.pp_opcode_value opcode)) (pp_value op); *) match Eval.eval_opcode_guard loc opcode op with - | Some op -> + | Some op -> if !Eval.trace_instruction then Printf.printf "TRACE: instruction %s\n" (pprint_ident nm); let@ () = DisEnv.traverse_ (function (IField_Field (f, lo, wd)) -> @@ -1479,7 +1493,7 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op: | DecoderBody_UNALLOC loc -> raise (Throw (loc, Exc_Undefined)) | DecoderBody_NOP loc -> DisEnv.pure true | DecoderBody_Encoding (inst, l) -> - let@ (enc, opost, cond, exec) = DisEnv.reads (fun env -> Eval.Env.getInstruction loc env inst) in + let@ (enc, opost, cond, exec) = DisEnv.reads (fun config -> Eval.Env.getInstruction loc config.eval_env inst) in let@ enc_match = dis_encoding enc op in if enc_match then begin (* todo: should evaluate ConditionHolds to decide whether to execute body *) @@ -1539,15 +1553,19 @@ let enum_types env i = | Some l -> Some (Z.log2up (Z.of_int (List.length l))) | _ -> None -let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list = +(* Actually perform dis *) +let dis_core (env: Eval.Env.t) (unroll_bound) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list = let DecoderCase_Case (_,_,loc) = decode in - let ((),lenv',stmts) = (dis_decode_case loc decode op) env lenv in + let config = { eval_env = env ; unroll_bound } in + + let ((),lenv',stmts) = (dis_decode_case loc decode op) config lenv in let varentries = List.(concat @@ map (fun vars -> StringMap.(bindings (map fst vars))) lenv.locals) in let bindings = Asl_utils.Bindings.of_seq @@ List.to_seq @@ List.map (fun (x,y) -> (Ident x,y)) varentries in (* List.iter (fun (v,t) -> Printf.printf ("%s:%s\n") v (pp_type t)) varentries; *) let stmts = flatten stmts [] in let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in let stmts' = Transforms.RedundantSlice.do_transform Bindings.empty stmts' in + let stmts' = Transforms.FixRedefinitions.run (globals : IdentSet.t) stmts' in let stmts' = Transforms.StatefulIntToBits.run (enum_types env) stmts' in let stmts' = Transforms.IntToBits.ints_to_bits stmts' in let stmts' = Transforms.CommonSubExprElim.do_transform stmts' in @@ -1556,7 +1574,6 @@ let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_cas let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts' in let stmts' = Transforms.CaseSimp.do_transform stmts' in let stmts' = Transforms.RemoveRegisters.run stmts' in - let stmts' = Transforms.FixRedefinitions.run (globals : IdentSet.t) stmts' in if !debug_level >= 2 then begin let stmts' = Asl_visitor.visit_stmts (new Asl_utils.resugarClass (!TC.binop_table)) stmts' in @@ -1566,6 +1583,17 @@ let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_cas end; stmts' +(* Wrapper around the core to attempt loop vectorization, reverting back if this fails. + This is a complete hack, but it is nicer to make the loop unrolling decision during + partial evaluation, rather than having to unroll after we know vectorization failed. + *) +let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: Primops.bigint): stmt list = + let unroll_bound = Z.of_int 1 in + let stmts' = dis_core env unroll_bound (lenv,globals) decode op in + let (res,stmts') = Transforms.LoopClassify.run stmts' env in + if res then stmts' else + dis_core env (Z.of_int 1000) (lenv,globals) decode op + let build_env (env: Eval.Env.t): env = let env = Eval.Env.freeze env in let lenv = LocalEnv.init env in diff --git a/libASL/dune b/libASL/dune index 0bb060c8..dd6bf933 100644 --- a/libASL/dune +++ b/libASL/dune @@ -39,7 +39,7 @@ lexer lexersupport loadASL monad primops rws symbolic tcheck testing transforms value symbolic_lifter decoder_program call_graph req_analysis offline_transform ocaml_backend dis_tc offline_opt - arm_env + arm_env pretransforms ) (preprocessor_deps (alias ../asl_files)) (preprocess (pps ppx_blob)) diff --git a/libASL/eval.ml b/libASL/eval.ml index d06c4b0b..ffca6e57 100644 --- a/libASL/eval.ml +++ b/libASL/eval.ml @@ -1179,7 +1179,7 @@ and eval_encoding (env: Env.t) (x: encoding) (op: Primops.bigint): bool = (* todo: consider checking iset *) (* Printf.printf "Checking opcode match %s == %s\n" (Utils.to_string (PP.pp_opcode_value opcode)) (pp_value op); *) match eval_opcode_guard loc opcode op with - | Some op -> + | Some op -> if !trace_instruction then Printf.printf "TRACE: instruction %s\n" (pprint_ident nm); List.iter (function (IField_Field (f, lo, wd)) -> let v = extract_bits' loc op lo wd in @@ -1211,7 +1211,7 @@ let build_evaluation_environment (ds: AST.declaration list): Env.t = begin if false then Printf.printf "Building environment from %d declarations\n" (List.length ds); (* perform reference parameter transformation. *) - let ds = Transforms.RefParams.ref_param_conversion ds in + let ds = Pretransforms.RefParams.ref_param_conversion ds in let env = Env.empty in (* todo?: first pull out the constants/configs and evaluate all of them @@ -1346,14 +1346,14 @@ let evaluate_prj_minimal (tcenv: Tcheck.Env.t) (env: Env.t) (source: LoadASL.sou (** Constructs an evaluation environment with the given prelude file and .asl/.prj files. .prj files given here are required to be minimal. *) -let evaluation_environment (prelude: LoadASL.source) (files: LoadASL.source list) (verbose: bool) = +let evaluation_environment (prelude: LoadASL.source) (files: LoadASL.source list) (verbose: bool) = let t = LoadASL.read_file (prelude) true verbose in let ts = List.map (fun file -> let filename = LoadASL.name_of_source file in if Utils.endswith filename ".spec" then begin - LoadASL.read_spec file verbose + LoadASL.read_spec file verbose end else if Utils.endswith filename ".asl" then begin - LoadASL.read_file file false verbose + LoadASL.read_file file false verbose end else if Utils.endswith filename ".prj" then begin [] (* ignore project files here and process later *) end else begin diff --git a/libASL/pretransforms.ml b/libASL/pretransforms.ml new file mode 100644 index 00000000..23630bb6 --- /dev/null +++ b/libASL/pretransforms.ml @@ -0,0 +1,174 @@ +open Asl_utils + +open AST +open Visitor + +(** Transforms setters using formal reference (in/out) parameters + into functions returning modified versions of the reference parameters. +*) +module RefParams = struct + + (** Filters the given list of sformal, returning a list of + (argument index, type, argument name) with only the ref params. *) + let get_ref_params (xs: sformal list): (int * ty * ident) list = + let xs = List.mapi (fun i x -> (i,x)) xs in + List.filter_map + (fun (n,f) -> + match f with + | Formal_InOut (t,i) -> Some (n,t,i) + | _ -> None) + xs + + (** Replaces all procedure returns in the given statement list + with the given statement. *) + let replace_returns ss s = + let visit = object + inherit Asl_visitor.nopAslVisitor + method! vstmt = + function + | Stmt_ProcReturn _ -> ChangeTo s + | Stmt_FunReturn _ -> failwith "unexpected function return in ref param conversion." + | _ -> DoChildren + end + in + Asl_visitor.visit_stmts visit ss + + (** Replaces setter declarations which use formal in-out parameters with + functions which return their modified parameters. + + For example, + + Elem[bits(N) &vector, integer e] = bits(size) value + ... + return; + + is transformed to + + (bits(N)) Elem.read(bits(N) vector, integer e, bits(size) value) + ... + return (vector); + + + *) + class visit_decls = object + inherit Asl_visitor.nopAslVisitor + + (* mapping of function identifiers to their (new) signature along with + the indices of their. *) + val mutable ref_params : (Tcheck.funtype * int list) Bindings.t = Bindings.empty + + method ref_params = ref_params + + method! vdecl (d: declaration): declaration Visitor.visitAction = + match d with + | Decl_ArraySetterDefn (nm, args, vty, vnm, body, loc)-> + (match get_ref_params args with + | [] -> DoChildren + | refs -> + (* indices, types, and identifiers for the ref params. *) + let ns = List.map (fun (n,_,_) -> n) refs in + let ts = List.map (fun (_,t,_) -> t) refs in + let is = List.map (fun (_,_,i) -> i) refs in + + (* append setter value argument to formal argument list. *) + let args' = List.map Tcheck.formal_of_sformal args @ [vty, vnm] in + + + (* construct return expression to return modified ref vars. *) + let vars = List.map (fun x -> Expr_Var x) is in + let ret = Stmt_FunReturn ((match vars with [x] -> x | _ -> Expr_Tuple vars), loc) in + let body' = replace_returns body [ret] in + + let rty = match ts with [t] -> t | _ -> Type_Tuple ts in + let funty = (nm, false, [], [], List.map Asl_visitor.arg_of_sformal args @ [(vty, vnm)], rty) in + ref_params <- Bindings.add nm (funty,ns) ref_params; + ChangeTo (Decl_FunDefn (rty, nm, args', body', loc)) + ) + | _ -> DoChildren + end + + (** Replaces writes to the setters modified above to assign + the return value back to the original variables. + + For example, + + Elem[vector, 2] = '1001'; + + is transformed to + + vector = Elem.read(vector, 2, '1001'); + + *) + class visit_writes (ref_params: (Tcheck.funtype * int list) Bindings.t) = object + inherit Asl_visitor.nopAslVisitor + + val mutable n = 0; + + method! vstmt (s: stmt) = + match s with + | Stmt_Assign (LExpr_Write (setter, targs, args), r, loc) -> + (match Bindings.find_opt setter ref_params with + | None -> DoChildren + | Some (_,ns) -> + let refs = List.map (List.nth args) ns in + (* Printf.printf "ref param: %s\n" (pp_expr a); *) + + let les = List.map Symbolic.expr_to_lexpr refs in + let call = Expr_TApply (setter, targs, args @ [r]) in + ChangeTo [Stmt_Assign ((match les with [x] -> x | _ -> LExpr_Tuple les), call, loc)] + ) + (* case where a write expression is used within a tuple destructuring. *) + | Stmt_Assign (LExpr_Tuple(LExpr_Write (setter, tes, es) :: rest), r, loc) -> + (match Bindings.find_opt setter ref_params with + | None -> DoChildren + | Some ((nm, _, _, _, args, _),ns) -> + + n <- n + 1; + (* create new variable to store value to be passed to setter. *) + let rvar = Ident ("Write_" ^ pprint_ident (stripTag setter) ^ string_of_int n) in + (* arguments to setter function appended with r-value. *) + let es' = es @ [Expr_Var rvar] in + + (* infer value argument type of setter by substituting arguments into + the last type argument. *) + let subs = List.combine (List.map snd args) es' in + let sub_bindings = Bindings.of_seq (List.to_seq subs) in + let (vty,_) = List.hd (List.rev args) in + let vty = subst_type sub_bindings vty in + + (* emit: vty rvar declaration *) + let decl_var = Stmt_VarDeclsNoInit (vty, [rvar], loc) in + (* emit: (rvar, ...) = r *) + let assign_tuple = Stmt_Assign (LExpr_Tuple (LExpr_Var rvar :: rest), r, loc) in + + let refs = List.map (List.nth es') ns in + let les = List.map Symbolic.expr_to_lexpr refs in + let write_call = Expr_TApply (setter, tes, es') in + (* emit: (refparams) = __write(es, rvar) *) + let assign_write = Stmt_Assign ((match les with [x] -> x | _ -> LExpr_Tuple les), write_call, loc) in + + let x = + [decl_var; assign_tuple; assign_write] + in + ChangeTo x + ) + | _ -> DoChildren + + method! vlexpr le = + match le with + | LExpr_Write (nm, _, _) when Bindings.mem nm ref_params -> + failwith @@ "unexpected write using parameters by reference: " ^ pp_lexpr le + | _ -> DoChildren + end + + let ref_param_conversion (ds: declaration list) = + let v1 = new visit_decls in + let ds = List.map (Asl_visitor.visit_decl (v1 :> Asl_visitor.aslVisitor)) ds in + let v2 = new visit_writes (v1#ref_params) in + let ds = List.map (Asl_visitor.visit_decl v2) ds in + ds + (* Tcheck.GlobalEnv.clear Tcheck.env0; + Tcheck.tc_declarations false ds *) +end + + diff --git a/libASL/symbolic.ml b/libASL/symbolic.ml index f597b28b..b9f7d79b 100644 --- a/libASL/symbolic.ml +++ b/libASL/symbolic.ml @@ -549,11 +549,16 @@ and sym_slice (loc: l) (x: sym) (lo: int) (wd: int): sym = (** Wrapper around sym_slice to handle cases of symbolic slice bounds *) let sym_extract_bits loc v i w = - match ( i, w) with - | (Val i', Val w') -> + match (v, i, w) with + (* Constant slice *) + | _, Val i', Val w' -> let i' = to_int loc i' in let w' = to_int loc w' in sym_slice loc v i' w' + (* Nested slice *) + | Exp (Expr_Slices (e, [Slice_LoWd (lo,wd)])), lo', wd' -> + let lo = sym_add_int loc (Exp lo) lo' in + Exp (Expr_Slices (e, [Slice_LoWd (sym_expr lo, sym_expr wd')])) | _ -> Exp (Expr_Slices (sym_expr v, [Slice_LoWd (sym_expr i, sym_expr w)])) let sym_zero_extend num_zeros old_width e = @@ -589,6 +594,12 @@ let sym_lsl_bits loc w x y = | _ -> sym_prim (FIdent ("LSL", 0)) [sym_of_int w] [x;y] +let zdiv_int x y = + match x, y with + | Val (VInt i), Val (VInt j) -> Val (VInt (Z.div i j)) + | _, Val (VInt i) when i = Z.one -> x + | _ -> Exp (Expr_TApply (FIdent ("sdiv_int", 0), [], [sym_expr x; sym_expr y])) + (** Overwrite bits from position lo up to (lo+wd) exclusive of old with the value v. Needs to know the widths of both old and v to perform the operation. Assumes width of v is equal to wd. @@ -606,6 +617,10 @@ let sym_insert_bits loc (old_width: int) (old: sym) (lo: sym) (wd: sym) (v: sym) else sym_append_bits loc (old_width - up) up (sym_slice loc old up (old_width - up)) (sym_append_bits loc wd lo v (sym_slice loc old 0 lo)) + | (_, _, Val wd', _) when Primops.prim_zrem_int (Z.of_int old_width) (to_integer Unknown wd') = Z.zero -> + (* Elem.set *) + let pos = zdiv_int lo wd in + Exp ( Expr_TApply (FIdent("Elem.set", 0), [expr_of_int old_width ; sym_expr wd], List.map sym_expr [old ; pos ; wd ; v]) ) | (_, _, Val wd', _) -> (* Build an insert out of bitvector masking operations *) let wd = to_int loc wd' in @@ -665,6 +680,25 @@ let is_insert_mask (b: bitvector): (int * int) option = end | _ -> None +(* +let rec elem_read_collapse vw ew v j = + match v with + | Expr_TApply (FIdent ("Elem.set", 0), [Expr_LitInt vw'; Expr_LitInt ew'], [v; Expr_LitInt i; _; e]) + when vw = Z.of_string vw' && ew = Z.of_string ew' && Z.of_string i = j -> + e + | Expr_TApply (FIdent ("Elem.set", 0), [Expr_LitInt vw'; Expr_LitInt ew'], [v; Expr_LitInt i; _; e]) + when vw = Z.of_string vw' && ew = Z.of_string ew' && Z.of_string i <> j -> + elem_read_collapse vw ew v j + | Expr_Slices (v, [Slice_LoWd(Expr_LitInt lo, Expr_LitInt wd)]) + when Z.equal (Z.rem (Z.of_string lo) ew) Z.zero -> + elem_read_collapse (Z.add vw vw) ew v (Z.add j (Z.div (Z.of_string lo) ew)) + | _ -> + (Expr_TApply (FIdent ("Elem.read", 0), [Expr_LitInt (Z.to_string vw); Expr_LitInt (Z.to_string ew)], [v; Expr_LitInt (Z.to_string j); Expr_LitInt (Z.to_string ew)])) +*) + (*| ("Elem.read", [Val (VInt vw); Val (VInt ew)], [Exp v; Val (VInt j); _]) -> + let e = elem_read_collapse vw ew v j in + Some (Exp e)*) + let sym_prim_simplify (name: string) (tes: sym list) (es: sym list): sym option = let loc = Unknown in @@ -680,7 +714,6 @@ let sym_prim_simplify (name: string) (tes: sym list) (es: sym list): sym option | ("sub_int", _, [x1; x2]) -> Some (sym_sub_int loc x1 x2) - | ("mul_int", _, [Val x1; x2]) when is_one x1 -> Some x2 | ("mul_int", _, [x1; Val x2]) when is_one x2 -> Some x1 | ("mul_int", _, [Exp (Expr_TApply (FIdent ("add_int", 0), [], [x1; Expr_LitInt v])); Val (VInt v2)]) -> diff --git a/libASL/symbolic_lifter.ml b/libASL/symbolic_lifter.ml index 6488855c..25d99f43 100644 --- a/libASL/symbolic_lifter.ml +++ b/libASL/symbolic_lifter.ml @@ -296,8 +296,9 @@ let dis_wrapper fn fnsig env = try let body = fnsig_get_body fnsig in let sym = Symbolic.Exp (Expr_Var (Decoder_program.enc)) in - let (_,lenv,_) = (Dis.declare_assign_var Unknown (Type_Bits (Expr_LitInt "32")) (Ident "enc") sym) env lenv in - let ((),lenv',stmts) = (Dis.dis_stmts body) env lenv in + let config = {Dis.eval_env = env ; unroll_bound = Z.zero} in + let (_,lenv,_) = (Dis.declare_assign_var Unknown (Type_Bits (Expr_LitInt "32")) (Ident "enc") sym) config lenv in + let ((),lenv',stmts) = (Dis.dis_stmts body) config lenv in let globals = IdentSet.diff globals dead_globals in let stmts = Dis.flatten stmts [] in let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in diff --git a/libASL/transforms.ml b/libASL/transforms.ml index 4e9a939a..511ce02f 100644 --- a/libASL/transforms.ml +++ b/libASL/transforms.ml @@ -189,178 +189,6 @@ module RemoveUnused = struct ) xs ([], used) end - - -(** Transforms setters using formal reference (in/out) parameters - into functions returning modified versions of the reference parameters. -*) -module RefParams = struct - - (** Filters the given list of sformal, returning a list of - (argument index, type, argument name) with only the ref params. *) - let get_ref_params (xs: sformal list): (int * ty * ident) list = - let xs = List.mapi (fun i x -> (i,x)) xs in - List.filter_map - (fun (n,f) -> - match f with - | Formal_InOut (t,i) -> Some (n,t,i) - | _ -> None) - xs - - (** Replaces all procedure returns in the given statement list - with the given statement. *) - let replace_returns ss s = - let visit = object - inherit Asl_visitor.nopAslVisitor - method! vstmt = - function - | Stmt_ProcReturn _ -> ChangeTo [s] - | Stmt_FunReturn _ -> failwith "unexpected function return in ref param conversion." - | _ -> DoChildren - end - in - Asl_visitor.visit_stmts visit ss - - (** Replaces setter declarations which use formal in-out parameters with - functions which return their modified parameters. - - For example, - - Elem[bits(N) &vector, integer e] = bits(size) value - ... - return; - - is transformed to - - (bits(N)) Elem.read(bits(N) vector, integer e, bits(size) value) - ... - return (vector); - - - *) - class visit_decls = object - inherit Asl_visitor.nopAslVisitor - - (* mapping of function identifiers to their (new) signature along with - the indices of their. *) - val mutable ref_params : (Tcheck.funtype * int list) Bindings.t = Bindings.empty - - method ref_params = ref_params - - method! vdecl (d: declaration): declaration visitAction = - match d with - | Decl_ArraySetterDefn (nm, args, vty, vnm, body, loc)-> - (match get_ref_params args with - | [] -> DoChildren - | refs -> - (* indices, types, and identifiers for the ref params. *) - let ns = List.map (fun (n,_,_) -> n) refs in - let ts = List.map (fun (_,t,_) -> t) refs in - let is = List.map (fun (_,_,i) -> i) refs in - - (* append setter value argument to formal argument list. *) - let args' = List.map Tcheck.formal_of_sformal args @ [vty, vnm] in - - (* construct return expression to return modified ref vars. *) - let vars = List.map (fun x -> Expr_Var x) is in - let ret = Stmt_FunReturn (Expr_Tuple vars, loc) in - let body' = replace_returns body ret in - - let rty = Type_Tuple ts in - let funty = (nm, false, [], [], List.map arg_of_sformal args @ [(vty, vnm)], rty) in - ref_params <- Bindings.add nm (funty,ns) ref_params; - ChangeTo (Decl_FunDefn (rty, nm, args', body', loc)) - ) - | _ -> DoChildren - end - - (** Replaces writes to the setters modified above to assign - the return value back to the original variables. - - For example, - - Elem[vector, 2] = '1001'; - - is transformed to - - vector = Elem.read(vector, 2, '1001'); - - *) - class visit_writes (ref_params: (Tcheck.funtype * int list) Bindings.t) = object - inherit Asl_visitor.nopAslVisitor - - val mutable n = 0; - - method! vstmt (s: stmt): stmt list visitAction = - singletonVisitAction @@ match s with - | Stmt_Assign (LExpr_Write (setter, targs, args), r, loc) -> - (match Bindings.find_opt setter ref_params with - | None -> DoChildren - | Some (_,ns) -> - let refs = List.map (List.nth args) ns in - (* Printf.printf "ref param: %s\n" (pp_expr a); *) - - let les = List.map Symbolic.expr_to_lexpr refs in - let call = Expr_TApply (setter, targs, args @ [r]) in - ChangeTo (Stmt_Assign (LExpr_Tuple les, call, loc)) - ) - (* case where a write expression is used within a tuple destructuring. *) - | Stmt_Assign (LExpr_Tuple(LExpr_Write (setter, tes, es) :: rest), r, loc) -> - (match Bindings.find_opt setter ref_params with - | None -> DoChildren - | Some ((nm, _, _, _, args, _),ns) -> - - n <- n + 1; - (* create new variable to store value to be passed to setter. *) - let rvar = Ident ("Write_" ^ pprint_ident (stripTag setter) ^ string_of_int n) in - (* arguments to setter function appended with r-value. *) - let es' = es @ [Expr_Var rvar] in - - (* infer value argument type of setter by substituting arguments into - the last type argument. *) - let subs = List.combine (List.map snd args) es' in - let sub_bindings = Bindings.of_seq (List.to_seq subs) in - let (vty,_) = List.hd (List.rev args) in - let vty = subst_type sub_bindings vty in - - (* emit: vty rvar declaration *) - let decl_var = Stmt_VarDeclsNoInit (vty, [rvar], loc) in - (* emit: (rvar, ...) = r *) - let assign_tuple = Stmt_Assign (LExpr_Tuple (LExpr_Var rvar :: rest), r, loc) in - - let refs = List.map (List.nth es') ns in - let les = List.map Symbolic.expr_to_lexpr refs in - let write_call = Expr_TApply (setter, tes, es') in - (* emit: (refparams) = __write(es, rvar) *) - let assign_write = Stmt_Assign (LExpr_Tuple les, write_call, loc) in - - let x = (Stmt_If ( - expr_true, - [decl_var; assign_tuple; assign_write], - [], - [], - loc)) in - ChangeTo x - ) - | _ -> DoChildren - - method! vlexpr le = - match le with - | LExpr_Write (nm, _, _) when Bindings.mem nm ref_params -> - failwith @@ "unexpected write using parameters by reference: " ^ pp_lexpr le - | _ -> DoChildren - end - - let ref_param_conversion (ds: declaration list) = - let v1 = new visit_decls in - let ds = List.map (Asl_visitor.visit_decl (v1 :> Asl_visitor.aslVisitor)) ds in - let v2 = new visit_writes (v1#ref_params) in - let ds = List.map (Asl_visitor.visit_decl v2) ds in - ds - (* Tcheck.GlobalEnv.clear Tcheck.env0; - Tcheck.tc_declarations false ds *) -end - module StatefulIntToBits = struct type interval = (Z.t * Z.t) type abs = (int * bool * interval) @@ -379,7 +207,8 @@ module StatefulIntToBits = struct else let u' = if Z.gt u Z.zero then 1 + (Z.log2up (Z.succ u)) else 1 in let l' = if Z.lt l Z.zero then 1 + (Z.log2up (Z.neg l)) else 1 in - (max u' l',true) + let i = max u' l' in + (i,true) (** Build an abstract point to represent a constant integer *) let abs_of_const (c: Z.t): abs = @@ -579,9 +408,8 @@ module StatefulIntToBits = struct let x = bv_of_int_expr st x in let w = abs_of_uop (snd x) Primops.prim_neg_int in let ex = extend w in - let f = sym_prim (FIdent ("not_bits", 0)) [sym_of_abs w] [ex x] in - let offset = Val (VBits {v=Z.one; n=width w}) in - let f = sym_prim (FIdent ("add_bits", 0)) [sym_of_abs w] [f; offset] in + let z = Val (VBits {v=Z.zero; n=width w}) in + let f = sym_prim (FIdent ("sub_bits", 0)) [sym_of_abs w] [z; ex x] in (f,w) (* TODO: Somewhat haphazard translation from old approach *) @@ -741,7 +569,7 @@ module StatefulIntToBits = struct (match Bindings.find_opt v vars with | Some w -> (*Printf.printf "transform_int_expr: Found root var: %s\n" (match v with Ident s -> s | _ -> "");*) - let prim = if signed w then "cvt_bits_int" else "cvt_bits_uint" in + let prim = if signed w then "cvt_bits_sint" else "cvt_bits_uint" in ChangeTo (expr_prim' prim [expr_of_abs w] [e]) | None -> SkipChildren) | _ -> DoChildren @@ -767,7 +595,8 @@ module StatefulIntToBits = struct let (w,s,_) = merge_abs i j in let m = (w,s,interval i) in {st with changed = true ; vars = Bindings.add v m st.vars} - | None -> {st with changed = true ; vars = Bindings.add v i st.vars} + | None -> + {st with changed = true ; vars = Bindings.add v i st.vars} (** Same as above, but keep as int TODO: This shouldn't be necessary, simplify in future. *) let assign_int (v: ident) (i: abs) (st: state): state = @@ -782,7 +611,8 @@ module StatefulIntToBits = struct let (w,s,_) = merge_abs i j in let m = (w,s,interval i) in {st with changed = true ; ints = Bindings.add v m st.ints} - | None -> {st with changed = true ; ints = Bindings.add v i st.ints} + | None -> + {st with changed = true ; ints = Bindings.add v i st.ints} (** Simple test of existence in state *) let tracked (v: ident) (st: state): bool = @@ -989,6 +819,8 @@ module IntToBits = struct | FIdent ("replicate_bits", 0), [Expr_LitInt n; Expr_LitInt m], _ -> int_of_string n * int_of_string m | FIdent ("ZeroExtend", 0), [_; Expr_LitInt m], _ | FIdent ("SignExtend", 0), [_; Expr_LitInt m], _ -> int_of_string m + | FIdent ("Elem.read", 0), [_; Expr_LitInt m], _ -> int_of_string m + | FIdent ("Elem.set", 0), [Expr_LitInt v;_], _ -> int_of_string v | _ -> failwith @@ "bits_size_of_expr: unhandled " ^ pp_expr e ) | Expr_Parens e -> bits_size_of_expr vars e @@ -1509,9 +1341,9 @@ module RedundantSlice = struct (* Last chance to convert dynamic slices into shift & static slice *) | Expr_Slices(x, [Slice_LoWd(l,w)]) when non_const l -> (match option_or (infer_type x) (self#var_type' x) with - | Some (Type_Bits xw) -> + (*| Some (Type_Bits xw) -> let e = Expr_TApply (FIdent ("LSR", 0), [xw], [x; l]) in - Expr_Slices(e, [Slice_LoWd (Expr_LitInt "0", w)]) + Expr_Slices(e, [Slice_LoWd (Expr_LitInt "0", w)])*) | _ -> e) | Expr_Slices(e', [Slice_LoWd (Expr_LitInt "0", wd)]) -> let try_match (opt: ty option): expr = @@ -1832,17 +1664,22 @@ module FixRedefinitions = struct val scoped_bindings : var_t ScopedBindings.t = let s = Stack.create () in Stack.push (Bindings.empty) s ; s + val mutable global_bindings : var_t Bindings.t = + Bindings.empty method push_scope (_:unit) : unit = push_scope scoped_bindings () method pop_scope (_:unit) : unit = pop_scope scoped_bindings () method add_bind (n: var_t) : unit = add_bind scoped_bindings n.name n method existing_binding (i: ident) : var_t option = find_binding scoped_bindings i + method global_binding (i: ident) : var_t option = Bindings.find_opt i global_bindings method incr_binding (i: ident) : var_t = - let v = this#existing_binding i in - match v with + let v = this#global_binding i in + let r = (match v with | Some b -> {b with index = b.index + 1} - | None -> {name=i; index=0} + | None -> {name=i; index=0}) in + global_bindings <- Bindings.add i r global_bindings; + r method! vstmt s = singletonVisitAction @@ match s with @@ -1869,9 +1706,11 @@ module FixRedefinitions = struct let start' = visit_expr this start in let stop' = visit_expr this stop in this#push_scope (); + let v = this#incr_binding var in + this#add_bind v; let body' = visit_stmts this body in this#pop_scope (); - ChangeTo (Stmt_For (var, start', dir, stop', body', loc)) + ChangeTo (Stmt_For (ident_for_v v, start', dir, stop', body', loc)) (* Statements with child scopes that shouldn't appear towards the end of transform pipeline *) | Stmt_Case _ -> failwith "(FixRedefinitions) case not expected" | Stmt_While _ -> failwith "(FixRedefinitions) while not expected" @@ -1895,3 +1734,1009 @@ module FixRedefinitions = struct let v = new redef_renamer g in visit_stmts v s end + + +(* + The analysis attempts to argue that all loop computations + are in parallel, by encoding their simultaneous calculation given + the loop's initial state. + + This is encoded in the abstract domain as the following: + f: var -> state -> value list + where f(var,state) returns a list of values, where the Nth value + of the list corresponds to the value of 'var' on the Nth loop + iteration, when the loop started with state 'state'. + + This function encoding (state -> value list) is encoded in type abs, + which we conceptually consider as type 'abs = {expr list}'. + The semantics of this encoding is listed below: + + collate es = + if List.exists empty es then [] else + let hds = map hd es in + let tails = map tail es in + hds::(collate tails) + + (* + sem : state -> {expr list} -> value list + We use {expr list} to represent our abstract type, corresponding to a list + of expressions implicitly. + Note that state never changes. + N is the number of loop iterations. + *) + + sem state VecOp(op : ident, tes : expr list, es : {expr list} list) : value list = + (* Types should be constant *) + let tes' = map (eval state) tes in + (* Evaluate each argument *) + let es' = map (sem state) es in + (* Collapse list of vector args into vector of args *) + let es' = collate es' in + (* Apply op to each collection of args *) + map (fun es -> op(tes', es)) es' + + sem state Read(vec : expr, pos : {expr list}, width : expr) : value list = + (* Types and width should be constant *) + let vec' = eval state val in + let width' = eval state width in + (* Evaluate pos encoding into a list of values *) + let pos' = sem state pos in + (* Map each position to a result *) + map (fun pos -> Elem.read(vec',pos,width')) pos' + + sem state Constant(e : expr) : value list = + (* Evaluate the constant expression once *) + let e' = eval state e in + (* Produce a vector of this constant N times *) + init N (fun _ -> e') + + sem state Index(e : expr, mul : expr) : value list = + (* Evaluate the constant base & mul expressions once *) + let e' = eval state e in + let m' = eval state mul in + (* Produce a vector of this base + mul * i N times *) + init N (fun i -> e' + m' * i) + + sem state BV(e : {expr list}, w : expr, s : bool) : value list = + let e' = sem state e in + let w' = eval state w in + map (fun e -> int_to_bits(e,w',s)) e' + + sem state Write(var : ident, pos : {expr list}, width : expr, e : {expr list}) : value list = + let pos' = sem state pos in + let width' = eval state width in + let e' = sem state e in + (* This write corresponds to a full definition of var *) + (* Knowing this simplifies its subsequent transform into a vector op *) + assert (unique pos'); + assert (length pos' * width' = width_of_var var); + (* Sort values based on their position *) + map snd (sort fst (zip pos' e')) + + We derive these encodings in a single loop pass, assuming no + interference between loop iterations and no failed assertions. + These are then validated given the summary of loop effects: + - asserts over the Write operation are checked + - interference between loop bodies is checked + + Interference may occur when expressions depend on the results + of prior loop iterators, e.g., a loop counter i := i + 1. + In some cases, such as the loop counter, these values can + be rephrased to be order independent, e.g., Index(i,1) for + the given example. + We apply these transforms and re-attempt the summary. +*) +module LoopClassify = struct + (**************************************************************** + * Symbolic Helpers + ****************************************************************) + + let mk_ite c w a b = + match c with + | Expr_Var (Ident "TRUE" ) -> a + | Expr_Var (Ident "FALSE") -> b + | _ -> Expr_TApply (FIdent("ite", 0), [w], [c;a;b]) + + let expr_of_Z i = Expr_LitInt (Z.to_string i) + let print_bv ({n;v} : Primops.bitvector) = (Z.format ("%0" ^ string_of_int n ^ "b") v) + let expr_of_bv (bv : Primops.bitvector) = Expr_LitBits(print_bv bv) + + let parse_bits s = + match from_bitsLit s with + | VBits v -> v + | _ -> failwith @@ "parse_bits: " ^ s + + let parse_signed_bits s = + match from_bitsLit s with + | VBits v -> Primops.z_signed_extract v.v 0 v.n + | _ -> failwith @@ "parse_signed_bits: " ^ s + + let cvt_int_bits a w = + match a, w with + | Expr_LitInt a, Expr_LitInt w -> + expr_of_bv (Primops.prim_cvt_int_bits (Z.of_string w) (Z.of_string a)) + | _ -> Expr_TApply(FIdent("cvt_int_bits",0), [w], [a; w]) + + let add_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b -> Expr_LitInt (Z.to_string (Z.add (Z.of_string a) (Z.of_string b))) + | Expr_LitInt "0", b -> b + | a, Expr_LitInt "0" -> a + | _ -> Expr_TApply(FIdent("add_int",0), [], [a;b]) + + let sub_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b -> Expr_LitInt (Z.to_string (Z.sub (Z.of_string a) (Z.of_string b))) + | _ -> Expr_TApply(FIdent("sub_int",0), [], [a;b]) + + let mul_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b -> Expr_LitInt (Z.to_string (Z.mul (Z.of_string a) (Z.of_string b))) + | _ -> Expr_TApply(FIdent("mul_int",0), [], [a;b]) + + let div_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b -> Expr_LitInt (Z.to_string (Z.div (Z.of_string a) (Z.of_string b))) + | _ -> Expr_TApply(FIdent("div_int",0), [], [a;b]) + + let mod_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b -> Expr_LitInt (Z.to_string (Z.rem (Z.of_string a) (Z.of_string b))) + | _ -> Expr_TApply(FIdent("mod_int",0), [], [a;b]) + + let eq_int a b = + match a, b with + | Expr_LitInt a, Expr_LitInt b when Z.of_string a = Z.of_string b -> Expr_Var (Ident "TRUE") + | Expr_LitInt a, Expr_LitInt b -> Expr_Var (Ident "FALSE") + | _ -> Expr_TApply(FIdent("eq_int",0), [], [a;b]) + + let ite_int c a b = + match c, a, b with + | Expr_Var (Ident "TRUE"), _, _ -> a + | Expr_Var (Ident "FALSE"), _, _ -> b + | _, a,b when a = b -> a + | _ -> Expr_TApply(FIdent("ite_int",0), [], [c;a;b]) + + let zero_int = Expr_LitInt "0" + + let add_bits w a b = + match a, b with + | Expr_LitBits a, Expr_LitBits b -> expr_of_bv (Primops.prim_add_bits (parse_bits a) (parse_bits b)) + | _ -> Expr_TApply(FIdent("add_bits",0), [w], [a;b]) + + let sub_bits w a b = + match a, b with + | Expr_LitBits a, Expr_LitBits b -> expr_of_bv (Primops.prim_sub_bits (parse_bits a) (parse_bits b)) + | _ -> Expr_TApply(FIdent("sub_bits",0), [w], [a;b]) + + let mul_bits w a b = + match a, b with + | Expr_LitBits a, Expr_LitBits b -> expr_of_bv (Primops.prim_mul_bits (parse_bits a) (parse_bits b)) + | _ -> Expr_TApply(FIdent("mul_bits",0), [w], [a;b]) + + let zeroes w = + match w with + | Expr_LitInt w -> expr_of_bv { v = Z.zero ; n = int_of_string w } + | _ -> failwith "" + + let neg_bits w a = + sub_bits w (zeroes w) a + + let cvt_bits_sint a w = + match a, w with + | Expr_LitBits bv, Expr_LitInt w -> + let v = parse_signed_bits bv in + Expr_LitInt (Z.to_string v) + | _ -> Expr_TApply(FIdent("cvt_bits_sint",0), [w], [a]) + + let cvt_bits_uint a w = + match a, w with + | Expr_LitBits bv, Expr_LitInt w -> + let v = parse_bits bv in + Expr_LitInt (Z.to_string v.v) + | _ -> Expr_TApply(FIdent("cvt_bits_uint",0), [w], [a]) + + let sign_extend a w = + match a, w with + | _ -> Expr_TApply(FIdent("SignExtend",0), [w], [a]) + + let zero_extend a w = + match a, w with + | _ -> Expr_TApply(FIdent("ZeroExtend",0), [w], [a]) + + let append_bits w1 w2 x y = + match x, y with + | Expr_LitBits x, Expr_LitBits y -> Expr_LitBits (x ^ y) + | _ -> Expr_TApply (FIdent ("append_bits", 0), [w1;w2], [x;y]) + + (**************************************************************** + * Abstract Domain + ****************************************************************) + + type abs = + (* An expression that does not change across loop iterations *) + Constant of expr | + (* A vectorized operation, taking a series of type and standard arguments *) + VecOp of ident * expr list * abs list | + (* A read operation, given position, element width, vector width *) + Read of expr * abs * expr * expr | + (* A write operation, given destination, position, element width, vector width, value *) + Write of ident * abs * expr * expr * abs | + (* Base, mult *) + Index of expr * expr | + (* Base, mult, width *) + BVIndex of expr * expr * expr | + Undecl + + let rec pp_abs e = + match e with + | Constant e -> "Constant(" ^ pp_expr e ^ ")" + | VecOp (f,l,r) -> "VecOp(" ^ pprint_ident f ^ "," ^ Utils.pp_list pp_expr l ^ "," ^ Utils.pp_list pp_abs r ^ ")" + | Read (v,p,w,w') -> "Read(" ^ pp_expr v ^ "," ^ pp_abs p ^ "," ^ pp_expr w ^ "," ^ pp_expr w' ^ ")" + | Write (v,p,w,_,e) -> "Write(" ^ pprint_ident v ^ "," ^ pp_abs p ^ "," ^ pp_expr w ^ "," ^ pp_abs e ^ ")" + | Index (e,m) -> "Index(" ^ pp_expr e ^ "," ^ pp_expr m ^ ")" + | BVIndex (b,m,w) -> "BVIndex(" ^ pp_expr b ^ "," ^ pp_expr m ^ "," ^ pp_expr w ^ ")" + | Undecl -> "Undecl" + + let rec deps e = + match e with + | Constant e -> fv_expr e + | VecOp (_,tes,es) -> + let tes_deps = unionSets (List.map fv_expr tes) in + let es_deps = unionSets (List.map deps es) in + IdentSet.union tes_deps es_deps + | Read (v,p,w,ew) -> unionSets [fv_expr v; deps p; fv_expr w; fv_expr ew] + | Write (v,p,w,vw,e) -> IdentSet.add v (unionSets [deps p; fv_expr w; deps e; fv_expr vw]) + | Index (b,m) -> IdentSet.union (fv_expr b) (fv_expr m) + | BVIndex (b,m,_) -> IdentSet.union (fv_expr b) (fv_expr m) + | Undecl -> IdentSet.empty + + let is_vec_op e = + match e with + | Read _ -> true + | VecOp _ -> true + | _ -> false + + let is_constant e = + match e with + | Constant _ -> true + | _ -> false + + let force_constant e = + match e with + | Constant e -> e + | _ -> failwith @@ "force_constant: " ^ pp_abs e + + let concat_bits ls = + let body = fun (w,x) (yw,y) -> let b = append_bits w yw x y in (add_int w yw,b) in + match ls with + | x::xs -> let (_,r) = List.fold_left body x xs in r + | _ -> failwith "concat" + + (* + Helper to build a select vector operation, where x is bits(elems * elemw) and + integers in sels are less than elems. + *) + let select_vec elems elemw x sels = + let sels_len = Expr_LitInt (string_of_int (List.length sels)) in + let w = Expr_LitInt "32" in + let sels = List.rev sels in + let sels = concat_bits (List.map (fun e -> (w, cvt_int_bits e w)) sels) in + (match sels with + | Expr_LitBits _ -> () + | _ -> failwith @@ "Non-constant sels: " ^ pp_expr sels ); + Expr_TApply(FIdent("select_vec",0), [elems; sels_len; elemw], [x; sels]) + + let shuffle_vec elems elemw x y sels = + let sels_len = Expr_LitInt (string_of_int (List.length sels)) in + let w = Expr_LitInt "32" in + let sels = List.rev sels in + let sels = concat_bits (List.map (fun e -> (w, cvt_int_bits e w)) sels) in + (match sels with + | Expr_LitBits _ -> () + | _ -> failwith @@ "Non-constant sels: " ^ pp_expr sels ); + Expr_TApply(FIdent("shuffle_vec",0), [elems; sels_len; elemw], [x; y; sels]) + + let replicate elems elemw x = + Expr_TApply(FIdent("replicate_bits", 0), [elemw; elems], [x; elems]) + + let build_sels length fn = + match length with + | Expr_LitInt v -> (List.init (int_of_string v) (fun i -> fn (expr_of_int i))) + | _ -> failwith @@ "Non-constant length to build_sels: " ^ pp_expr length + + (**************************************************************** + * Analysis State + ****************************************************************) + + type state = { + (* Base Loop Properties *) + iterations: expr; + (* Variable Classification *) + vars: abs Bindings.t; + (* Loop Defined *) + ld: abs Bindings.t; + + (* Type Info *) + types: ty Bindings.t; + env: Eval.Env.t; + } + + (* Create the state for a single loop analysis, from its definition *) + let init_state var start stop dir types env = + let abs = match dir with + | Direction_Up -> Index(start, expr_of_Z Z.one) + | Direction_Down -> Index(stop, expr_of_Z (Z.neg Z.one)) in + let iterations = match dir with + | Direction_Up -> add_int (sub_int stop start) (expr_of_Z Z.one) + | Direction_Down -> add_int (sub_int start stop) (expr_of_Z Z.one) in + { iterations ; vars = Bindings.empty ; ld = Bindings.add var abs Bindings.empty ; types ; env } + + let get_var v st = + match Bindings.find_opt v st.ld with + | Some v -> Some v + | None -> Bindings.find_opt v st.vars + + let decl_ld v i st = + {st with ld = Bindings.add v i st.ld} + + let assign_var v i st = + if Bindings.mem v st.ld then + {st with ld = Bindings.add v i st.ld} + else + {st with vars = Bindings.add v i st.vars} + + let width_of_expr e st = + match Dis_tc.infer_type e st.types st.env with + | Some (Type_Bits(Expr_LitInt i)) -> (int_of_string i) + | Some (Type_Constructor (Ident "boolean")) -> 1 + | Some (Type_Register(w, _)) -> (int_of_string w) + | _ -> failwith @@ "Unknown expression type: " ^ (pp_expr e) + + (**************************************************************** + * Phase 1: Produce a candidate loop summary + ****************************************************************) + + let vector_ops = [ + "not_bool"; + "and_bool"; + "or_bool"; + "add_bits"; + "sub_bits"; + "mul_bits"; + "sdiv_bits"; + "sle_bits"; + "slt_bits"; + "eq_bits"; + "asr_bits"; + "lsl_bits"; + "not_bits"; + "and_bits"; + "or_bits"; + "eor_bits"; + "append_bits"; + "ZeroExtend"; + "SignExtend"; + ] + + (* Transfer function for a primitive application *) + let tf_prim st f i tes es = + match f, i, tes, es with + (* Everything is constant, can skip *) + | f, i, tes, es when List.for_all is_constant tes && List.for_all is_constant es -> + Constant (Expr_TApply(FIdent(f,i), List.map force_constant tes, List.map force_constant es)) + + (* Supported operations over Index expressions *) + | "cvt_int_bits", 0, [Constant w], [Index(b,m);_] -> + let base = cvt_int_bits b w in + let mult = cvt_int_bits m w in + BVIndex (base, mult, w) + | "add_int", 0, [], [Index (base,mul);Constant offset] + | "add_int", 0, [], [Constant offset;Index (base,mul)] -> + Index (add_int base offset, mul) + | "sub_int", 0, [], [Index (base,mul);Constant offset] -> + Index (sub_int base offset, mul) + | "sub_int", 0, [], [Constant offset;Index (base,mul)] -> + Index (sub_int offset base, mul_int mul (Expr_LitInt "-1")) + | "mul_int", 0, [], [Index (base,mul);Constant offset] + | "mul_int", 0, [], [Constant offset;Index (base,mul)] -> + Index (mul_int base offset, mul_int mul offset) + | "sdiv_int", 0, [], [Index(base,mul);Constant div] -> + Index (div_int base div, div_int mul div) + + (* Supported operations over BVIndex TODO: These don't really handle overflow properly *) + | "cvt_bits_sint", 0, [Constant w], [BVIndex(b,m,_)] -> + let base = cvt_bits_sint b w in + let mult = cvt_bits_sint m w in + Index (base, mult) + | "cvt_bits_uint", 0, [Constant w], [BVIndex(b,m,_)] -> + let base = cvt_bits_uint b w in + let mult = cvt_bits_uint m w in + Index (base, mult) + | "ZeroExtend", 0, [Constant oldw; Constant neww], [BVIndex(b,m,_); _] -> + let base = zero_extend b neww in + let mult = zero_extend m neww in + BVIndex (base, mult, neww) + | "SignExtend", 0, [Constant oldw; Constant neww], [BVIndex(b,m,_); _] -> + let base = sign_extend b neww in + let mult = sign_extend m neww in + BVIndex (base, mult, neww) + | "add_bits", 0, [Constant w], [BVIndex(base,mul,_);Constant offset] -> + BVIndex(add_bits w base offset, mul, w) + | "sub_bits", 0, [Constant w], [BVIndex(base,mul,_);Constant offset] -> + BVIndex(sub_bits w base offset, mul, w) + + (* Reading Operations *) + | "Elem.read", 0, [Constant vecw ; Constant elemw], [Constant v; pos; _] -> + Read (v, pos, elemw, vecw) + + (* Writing Operations *) + | "Elem.set", 0, [Constant vecw ; Constant elemw], [Constant (Expr_Var v); pos; _; arg] -> + Write (v, pos, elemw, vecw, arg) + (* Match offset Elem.set operations TODO: This would be cleaner as a simp rule later on *) + | "Elem.set", 0, [Constant vecw ; Constant elemw], [ + Write(v,Index(Expr_LitInt "0",Expr_LitInt "2"),elemw',_,arg'); + Index(Expr_LitInt "1",Expr_LitInt "2"); _; arg] + when elemw = elemw' -> + let a = VecOp(FIdent("append_bits",0), [elemw;elemw], [arg;arg']) in + Write (v, Index(Expr_LitInt "0",Expr_LitInt "1"), add_int elemw elemw', vecw, a) + + (* Vector Operations *) + | f, 0, tes, es when List.mem f vector_ops && List.exists is_vec_op es && List.for_all is_constant tes -> + VecOp (FIdent (f,0), List.map force_constant tes, es) + + | _ -> failwith @@ "Unknown loop prim: " ^ f ^ " " ^ Utils.pp_list pp_abs tes ^ " " ^ Utils.pp_list pp_abs es + + (* Transfer function for an expression *) + let rec tf_expr st e = + match e with + | Expr_Var v -> + (match get_var v st with + | Some abs -> abs + | None -> Constant e) + | Expr_LitBits _ -> Constant e + | Expr_LitInt _ -> Constant e + | Expr_TApply(FIdent(f,i), tes, es) -> + let tes = List.map (tf_expr st) tes in + let es = List.map (tf_expr st) es in + tf_prim st f i tes es + | Expr_Slices(e', [Slice_LoWd(lo,wd)]) -> + let ow = Expr_LitInt (string_of_int (width_of_expr e' st)) in + (match tf_expr st e', tf_expr st lo, tf_expr st wd with + (* Entirely constant, pass it through *) + | Constant e', Constant lo, Constant wd -> + Constant (Expr_Slices(e', [Slice_LoWd(lo, wd)])) + (* Constant slice position over vectorized expression *) + | e', Constant lo, Constant wd when is_vec_op e' -> + VecOp(FIdent("slice",0), [ow; wd; lo], [e']) + (* Index based slice over constant, corresponding to a vector read. + TODO: There is a more general approach to this. + *) + | Constant e', Index(b,m), Constant (Expr_LitInt "1") -> + Read(e', Index(b,m), Expr_LitInt "1", ow) + | Constant e', Index(b,m), Constant w when m = w -> + Read(e', Index(b,Expr_LitInt "1"), w, ow) + | a, b, c -> failwith @@ "Failed loop slice: " ^ Utils.pp_list pp_abs [a;b;c]) + | _ -> failwith @@ "Failed loop expr: " ^ pp_expr e + + (* Join abs a & b given the condition c *) + let join_abs w c a b = + match c, a, b with + | _, a, b when a = b -> a + (* This is a trivial result, constant for all loop iterations *) + | Constant c, Constant a, Constant b -> + Constant (Expr_TApply(FIdent("ite_bits",0), [w], [c;a;b])) + (* Vector base ite *) + | _ when List.for_all (fun v -> is_vec_op v || is_constant v) [c;a;b] -> + VecOp (FIdent("ite",0), [w], [c;a;b]) + | _ -> failwith @@ "Failed join_abs: " ^ pp_abs c ^ " ? " ^ pp_abs a ^ " : " ^ pp_abs b + + (* Join states a & b given the condition cond *) + let join_st cond st1 st2 = + (* Merge loop defined constructs, assume they are defined + on both paths *) + let ld = Bindings.merge (fun k l r -> + match l, r with + | Some l, Some r when l = r -> Some l + | Some l, Some r -> + let w = expr_of_int (width_of_expr (Expr_Var k) st1) in + Some (join_abs w cond l r) + | _ -> None) st1.ld st2.ld in + (* Merge external constructs, support conditional effects *) + let vars = Bindings.merge (fun k l r -> + match l, r with + (* Same effect *) + | Some l, Some r when l = r -> Some l + (* Conditional write *) + | Some (Write(v,pos,w,we,e)), None -> + let w' = expr_of_int (width_of_expr (Expr_Var k) st1) in + Some (Write(v,pos,w,we,join_abs w' cond e (Read(Expr_Var v,pos,w,we)))) + | None, Some (Write(v,pos,w,we,e)) -> + let w' = expr_of_int (width_of_expr (Expr_Var k) st1) in + Some (Write(v,pos,w,we,join_abs w' cond (Read(Expr_Var v,pos,w,we)) e)) + | Some (Constant e), None -> + let w' = expr_of_int (width_of_expr (Expr_Var k) st1) in + Some (join_abs w' cond (Constant e) (Constant (Expr_Var k))) + (* Conditional write *) + | _ -> + failwith @@ "Failed join_st: " ^ pprint_ident k ^ ":" ^ + (Utils.pp_option pp_abs l) ^ " " ^ Utils.pp_option pp_abs r + ) st1.vars st2.vars in + { st1 with vars; ld } + + (* Transfer function for a list of statements *) + let rec tf_stmts st s = + List.fold_left (fun st stmt -> + match stmt with + (* Loop Internal Calculations *) + | Stmt_ConstDecl(ty, v, e, loc) -> + let abs = tf_expr st e in + decl_ld v abs st + | Stmt_VarDecl(ty, v, e, loc) -> + let abs = tf_expr st e in + decl_ld v abs st + | Stmt_VarDeclsNoInit(ty, [v], loc) -> + decl_ld v Undecl st + | Stmt_Assign(LExpr_Var v, e, loc) -> + let abs = tf_expr st e in + assign_var v abs st + | Stmt_If(c, t, [], f, loc) -> + let abs = tf_expr st c in + let tst = tf_stmts st t in + let fst = tf_stmts st f in + join_st abs tst fst + | Stmt_Assert(e, loc) -> + (* TODO: We should actually validate or keep this around *) + st + | _ -> failwith @@ "Unknown loop stmt: " ^ pp_stmt stmt) st s + + (**************************************************************** + * Phase 2: Fixed Point Identification + ****************************************************************) + + (* + Given summaries of each externally scoped variable write, + determine if they are trivially parallelized. + + As a first phase, we attempt to show all externally scoped variables + are only self-referential, i.e., there is no dependence relation + between externally scoped variables. + The only exception to this is trivial reductions to functions over + the loop index, such as x := x + 1. + + Once we know variables are at most self-referential, we determine the + necessary reduction to capture their cumulative effects. + This occurs in Phase 3. + *) + + (* If we pre-load definitions for external values, fix them up here. + Only handles the case where we define a value to be a function of + index (x := (base + index * mult) and we anticipate the final value + to be an additional increment of mult. + *) + let amend_pre_load init_st st = + let vars = Bindings.mapi (fun var def -> + match def, Bindings.find_opt var init_st.vars with + | x, None -> x + | BVIndex(Expr_TApply(FIdent("add_bits",0), [w], [base;mult]),mult',w'), Some (BVIndex(base',mult'',w'')) + when base = base' && mult = mult' && w = w' && mult = mult'' && w = w'' -> + BVIndex(base',mult'',w'') + | BVIndex(Expr_TApply(FIdent("sub_bits",0), [w], [base;mult]),mult',w'), Some (BVIndex(base',mult'',w'')) + when base = base' && neg_bits w mult = mult' && w = w' && mult' = mult'' && w = w'' -> + BVIndex(base',mult'',w'') + | x, Some y -> + failwith @@ "Failed to re-establish initial conditions: " ^ pp_abs x ^ " and " ^ pp_abs y + ) st.vars in + { st with vars } + + (* Determine if the summary is valid: + 1. All constants are actually constant + 2. Modified variables are at most self-referential + Can produce fixes for cases where a constant is not constant, but instead a function of loop index. + *) + let validate_summary effects = + (* Identify possible fixes before validation *) + let constant_fixes = Bindings.fold (fun var def acc -> + match def with + | Constant (Expr_TApply(FIdent("add_bits",0), [w], [Expr_Var var'; b])) + when var = var' && not (IdentSet.mem var (fv_expr b)) -> + Bindings.add var (BVIndex(Expr_Var var, b, w)) acc + | Constant (Expr_TApply(FIdent("sub_bits",0), [w], [Expr_Var var'; b])) + when var = var' && not (IdentSet.mem var (fv_expr b)) -> + Bindings.add var (BVIndex(Expr_Var var, neg_bits w b, w)) acc + | _ -> acc) effects Bindings.empty in + (* If no fixes, validate *) + if constant_fixes <> Bindings.empty then constant_fixes else + let _ = Bindings.iter (fun var def -> + (* No cross-references *) + let _ = Bindings.iter (fun var' def' -> + match var,def,var',def' with + (* Allow for references to BVIndex vars *) + | _,BVIndex _,_,_-> () + (* Ignore self *) + | v,_,v',_ when v = v' -> () + (* Check for reference to var in def' *) + | v,d,v',d' when IdentSet.mem v (deps d') -> + failwith @@ "Cross-reference: " ^ pprint_ident v ^ " := " ^ pp_abs d ^ " && " ^ pprint_ident v' ^ " := " ^ pp_abs d' + | _ -> () + ) effects in + (* Constants are truely constant *) + match def with + | Constant e when IdentSet.disjoint (deps def) (bindings_domain effects) -> () + | Constant e -> failwith @@ "Failed to generalise: " ^ pprint_ident var ^ " := " ^ pp_abs def + | _ -> () + ) effects in + Bindings.empty + + (* Run the analysis from an initial state. + Re-runs if we identify abstractions for external state. + *) + let rec fixed_point init_st body = + let cand_st = tf_stmts init_st body in + let cand_st = amend_pre_load init_st cand_st in + let fixes = validate_summary cand_st.vars in + if fixes = Bindings.empty then cand_st + else + let init_st' = { init_st with vars = fixes } in + fixed_point init_st' body + + (**************************************************************** + * Phase 3: Build expression from the abstract points + ****************************************************************) + + (* + Convert abstract points into expressions. + *) + + (* + Build vector primitive operations from the abstract state. + *) + let rec build_vec_prim st f i tes es = + let iters = st.iterations in + let mul_iters i = mul_int i iters in + let std_args l = List.map (build_vec_expr st) l in + let vec_args l = (List.map (build_vec_expr st) l) @ [iters] in + + match f, i, tes, es with + (* Bool Ops, all applied bit-wise *) + | "not_bool", 0, [], [x] -> + Expr_TApply(FIdent("not_bits", 0), [iters], std_args [x]) + | "and_bool", 0, [], [x;y] -> + Expr_TApply(FIdent("and_bits", 0), [iters], std_args [x;y]) + | "or_bool", 0, [], [x;y] -> + Expr_TApply(FIdent("or_bits", 0), [iters], std_args [x;y]) + + (* Bit-wise Ops *) + | "not_bits", 0, [w], [x] -> + Expr_TApply(FIdent("not_bits", 0), [mul_iters w], std_args [x]) + | "and_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("and_bits", 0), [mul_iters w], std_args [x;y]) + | "or_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("or_bits", 0), [mul_iters w], std_args [x;y]) + | "eor_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("eor_bits", 0), [mul_iters w], std_args [x;y]) + + (* Element-wise Ops *) + | "add_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("add_vec", 0), [iters; w], vec_args [x; y]) + | "sub_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("sub_vec", 0), [iters; w], vec_args [x; y]) + | "mul_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("mul_vec", 0), [iters; w], vec_args [x; y]) + | "sdiv_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("sdiv_vec", 0), [iters; w], vec_args [x; y]) + | "sle_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("sle_vec", 0), [iters; w], vec_args [x; y]) + | "slt_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("slt_vec", 0), [iters; w], vec_args [x; y]) + | "eq_bits", 0, [w], [x;y] -> + Expr_TApply(FIdent("eq_vec", 0), [iters; w], vec_args [x; y]) + | "asr_bits", 0, [w;w'], [x;y] when w = w' -> + Expr_TApply(FIdent("asr_vec", 0), [iters; w], vec_args [x;y]) + | "asr_bits", 0, [w;w'], [x;y] -> + let y = Expr_TApply(FIdent("scast_vec", 0), [iters; w; w'], (vec_args [y]) @ [w]) in + Expr_TApply(FIdent("asr_vec", 0), [iters; w], [build_vec_expr st x;y;iters]) + | "lsl_bits", 0, [w;w'], [x;y] when w = w' -> + Expr_TApply(FIdent("lsl_vec", 0), [iters; w], vec_args [x;y]) + | "lsl_bits", 0, [w;w'], [x;y] -> + let y = Expr_TApply(FIdent("scast_vec", 0), [iters; w; w'], (vec_args [y]) @ [w]) in + Expr_TApply(FIdent("lsl_vec", 0), [iters; w], [build_vec_expr st x;y;iters]) + | "ite", 0, [w], [c;x;y] -> + Expr_TApply(FIdent("ite_vec", 0), [iters; w], vec_args [c;x;y]) + + (* Casts *) + | "ZeroExtend", 0, [ow;nw], [x;_] -> + Expr_TApply(FIdent("zcast_vec", 0), [iters; nw; ow], (vec_args [x]) @ [nw]) + | "SignExtend", 0, [ow;nw], [x;_] -> + Expr_TApply(FIdent("scast_vec", 0), [iters; nw; ow], (vec_args [x]) @ [nw]) + | "slice", 0, [ow;nw;lo], [x] -> + let shifted = + match lo with + | Expr_LitInt "0" -> build_vec_expr st x + | _ -> Expr_TApply(FIdent("lsr_vec",0), [iters; ow], vec_args [x; Constant (cvt_int_bits lo ow)]) in + Expr_TApply(FIdent("trunc_vec",0), [iters; nw; ow], [shifted; iters; nw]) + + (* Appends *) + (* Special case for a zero extend, convert into ZeroExtend *) + | "append_bits", 0, [wx;wy], [Constant (Expr_LitBits i);y] when (parse_bits i).v = Z.zero -> + let nw = add_int wx wy in + Expr_TApply(FIdent("zcast_vec", 0), [iters; nw; wy], (vec_args [y]) @ [nw]) + | "append_bits", 0, [wx;wy], [y;Constant (Expr_LitBits i)] when (parse_bits i).v = Z.zero -> + let nw = add_int wx wy in + let cast = Expr_TApply(FIdent("zcast_vec", 0), [iters; nw; wx], (vec_args [y]) @ [nw]) in + let shifts = replicate st.iterations nw (cvt_int_bits wy nw) in + Expr_TApply(FIdent("lsl_vec", 0), [iters; nw], [cast;shifts;iters]) + (* Append over two bitvectors turns into zip TODO: Extend to support gcd of their widths *) + | "append_bits", 0, [wx;wy], [x;y] when wx = wy -> + let input = Expr_TApply(FIdent("append_bits",0), [mul_iters wx; mul_iters wy], std_args [x;y]) in + let sels = build_sels st.iterations (fun i -> [ i ; add_int st.iterations i ]) in + let sels = List.fold_left (@) [] sels in + select_vec (add_int st.iterations st.iterations) wx input sels + + | _ -> failwith @@ "Unsupported conversion: " ^ f ^ " " ^ Utils.pp_list pp_expr tes ^ " " ^ Utils.pp_list pp_abs es + + (* + Turn an abstract point into an operation. + In effect, this corresponds to computing the abstract point simultaneously + for all loop iterations. + *) + and build_vec_expr st abs = + match abs with + (* Constant should not change, replicate it *) + | Constant expr -> + let w = expr_of_int (width_of_expr expr st) in + replicate st.iterations w expr + (* Vector Operation *) + | VecOp(FIdent(f,i), tes, es) -> + build_vec_prim st f i tes es + (* Read becomes a select operation, building based on its stride and width *) + | Read(expr,Index(base,mult),width,expr_width) -> + let sels = build_sels st.iterations (fun i -> add_int base (mul_int mult i)) in + let elems = div_int expr_width width in + select_vec elems width expr sels + (* Write is a select between a base variable and the overwritten component *) + | Write(var,Index(base,mult),width,var_width,e) -> + let elems = div_int var_width width in + let sels = build_sels elems (fun i -> + ite_int (eq_int (mod_int (sub_int i base) mult) zero_int) + (add_int (div_int (sub_int i base) mult) elems) i) in + let expr = append_bits (mul_int st.iterations width) var_width (build_vec_expr st e) (Expr_Var var) in + select_vec (add_int elems st.iterations) width expr sels + | _ -> failwith @@ "Failed to build vector expression for: " ^ pp_abs abs + + (* Identify cases where self-reference is limited to reads of a particular + position and width. *) + let rec parallel_write var pos width e = + match e with + | VecOp(f, tes, es) -> + let tes = not (IdentSet.mem var (unionSets (List.map fv_expr tes))) in + let es = List.map (parallel_write var pos width) es in + tes && List.for_all (fun s -> s) es + | Read(Expr_Var v',pos',width',_) when v' = var -> + pos = pos' && width = width' + | _ -> not (IdentSet.mem var (deps e)) + + (* For a variable and abstract point, produce an expression equivalent to the + parallel evaluation of the abstract point. *) + let summarize_assign st var abs = + match abs with + (* Result is not dependent on itself in anyway *) + | a when not (IdentSet.mem var (deps a)) -> + build_vec_expr st abs + (* Parallel Write, element is only dependent on itself *) + | Write(var',pos,width,exprw,e) when var = var' && parallel_write var pos width e -> + build_vec_expr st abs + (* Final result for a function of loop index *) + | BVIndex(Expr_Var base,mul,w) when base = var -> + let iters = cvt_int_bits st.iterations w in + let m = mul_bits w mul iters in + add_bits w (Expr_Var base) m + + (* Reduce Add *) + | VecOp(FIdent("add_bits", 0), [w], [Constant (Expr_Var var') ; e]) when is_vec_op e && var' = var && not (IdentSet.mem var (deps e)) -> + let e = build_vec_expr st e in + Expr_TApply ( FIdent("reduce_add",0), [st.iterations; w], [e ; Expr_Var var]) + + | _ -> failwith @@ "Failed to summarize " ^ pprint_ident var ^ " <- " ^ pp_abs abs + + (* Given a successful abstract representation of a loop, reduce its observable + effects into a series of assignments. + *) + let loop_summary st loc = + Bindings.fold (fun var abs acc -> + let c = summarize_assign st var abs in + let s = Stmt_Assign(LExpr_Var var, c, loc) in + s::acc (* TODO: Need temps for this *) + ) st.vars [] + + (**************************************************************** + * Analysis Entry Point + ****************************************************************) + + (* Map from inside out *) + let rec walk s types env = + List.fold_left (fun acc stmt -> + match stmt with + | Stmt_For(var, start, dir, stop, body, loc) -> + let body = walk body types env in + let st = init_state var start stop dir types env in + let st' = fixed_point st body in + let sum = loop_summary st' loc in + (acc@sum) + | _ -> (acc@[stmt])) ([]) s + + let parse_sels v = + let chars = String.length v in + let elems = chars / 32 in + let e = List.init elems (fun i -> let bv = parse_bits (String.sub v (i * 32) (32)) in Z.to_int bv.v) in + List.rev e + let print_sels sels = + List.fold_left (fun a s -> (print_bv { n = 32 ; v = Z.of_int s } ^ a)) "" sels + + let check_leq x y = + match x, y with + | Expr_LitInt x, Expr_LitInt y -> Z.leq (Z.of_string x) (Z.of_string y) + | _ -> false + + let check_lt x y = + match x, y with + | Expr_LitInt x, Expr_LitInt y -> Z.lt (Z.of_string x) (Z.of_string y) + | _ -> false + + let is_div x y = + match x, y with + | Expr_LitInt x, Expr_LitInt y -> (Z.rem (Z.of_string x) (Z.of_string y)) = Z.zero + | _ -> false + + let is_const_sels sels = + List.for_all (fun i -> match i with Expr_LitInt _ -> true | _ -> false) sels + let force_const_sels sels = + List.map (fun i -> match i with Expr_LitInt i -> int_of_string i | _ -> failwith "force_const") sels + + let apply_sels bv w sels = + let ins = (String.length bv) / w in + let vals = List.init ins (fun i -> String.sub bv (i * w) w) in + let vals = List.rev vals in + let res = List.map (fun i -> List.nth vals i) sels in + let res = List.rev res in + String.concat "" res + + let rec inc_by s is = + match is with + | (Expr_LitInt i)::(Expr_LitInt j)::is when Z.sub (Z.of_string j) (Z.of_string i) = s -> + inc_by s ((Expr_LitInt j)::is) + | [Expr_LitInt l] -> Some (Z.of_string l) + | _ -> None + + let is_slice sels w = + match sels with + | [i] -> Some (Slice_LoWd(mul_int i w,w)) + | (Expr_LitInt i)::is -> + let first = Z.of_string i in + (match inc_by (Z.one) sels with + | Some last -> + let diff = (Z.add (Z.sub last first) Z.one) in + let width = mul_int (expr_of_Z diff) w in + let lo = mul_int (expr_of_Z first) w in + Some (Slice_LoWd(lo, width)) + | _ -> None) + | _ -> None + + let is_const w = + match w with Expr_LitInt _ -> true | _ -> false + + let force_const w = + match w with Expr_LitInt i -> int_of_string i | _ -> failwith "" + + let rec push_select elems w x sels st = + match x with + | Expr_TApply (FIdent ("append_bits", 0), [wr;wl], [r;l]) + when List.for_all (fun i -> check_leq (mul_int w (add_int i (Expr_LitInt "1"))) wl) sels && + is_div wr w && is_div wl w -> + let elems = sub_int elems (div_int wr w) in + push_select elems w l sels st + | Expr_TApply (FIdent ("append_bits", 0), [wr;wl], [r;l]) + when List.for_all (fun i -> check_leq wl (mul_int w i)) sels && is_div wl w && is_div wr w -> + let shift = div_int wl w in + let sels = List.map (fun i -> sub_int i shift) sels in + let elems = sub_int elems (div_int wl w) in + push_select elems w r sels st + | Expr_TApply (FIdent ("append_bits", 0), [wr;wl], [r;l]) when wr = wl -> + let elems = div_int elems (expr_of_int 2) in + shuffle_vec elems w r l sels + + (* Comps *) + | Expr_TApply (FIdent (("sle_vec"|"eq_vec"|"slt_vec") as f, 0), ([_;w] as tes), [l;r;n]) when n = elems -> + let l = push_select elems w l sels st in + let r = push_select elems w r sels st in + Expr_TApply (FIdent (f, 0), tes, [l;r;n]) + (* Binops *) + | Expr_TApply (FIdent (("add_vec"|"sub_vec"|"mul_vec"|"asr_vec"|"lsr_vec"|"lsl_vec") as f, 0), tes, [l;r;n]) when n = elems -> + let l = push_select elems w l sels st in + let r = push_select elems w r sels st in + Expr_TApply (FIdent (f, 0), tes, [l;r;n]) + (* Casts *) + | Expr_TApply (FIdent (("trunc_vec"|"zcast_vec"|"scast_vec"), 0), [n;nw;ow], [x;n';nw']) when nw = ow -> + push_select elems ow x sels st + | Expr_TApply (FIdent (("trunc_vec"|"zcast_vec"|"scast_vec") as f, 0), [n;nw;ow], [x;n';nw']) when n = elems -> + let x = push_select elems ow x sels st in + Expr_TApply (FIdent (f, 0), [n;nw;ow], [x;n';nw']) + (* Ternary *) + | Expr_TApply (FIdent ("ite_vec", 0), tes, [c;l;r;n]) when n = elems -> + let c = push_select elems (expr_of_int 1) c sels st in + let r = push_select elems w r sels st in + let l = push_select elems w l sels st in + Expr_TApply (FIdent ("ite_vec", 0), tes, [c;l;r;n]) + + (* Replicate, given same element count no difference in result *) + | Expr_TApply (FIdent ("replicate_bits", 0), tes, [_;n]) when n = elems -> + x + + (* Slice from 0 to some width is redundant, just slice full expression directly *) + | Expr_Slices (x, [Slice_LoWd(lo, wd)]) when is_div lo w -> + let offset = div_int lo w in + let sels = List.map (add_int offset) sels in + let wd = width_of_expr x st in + let elems = div_int (expr_of_int wd) w in + push_select elems w x sels st + + (* Nested selects, easy case of matching element widths *) + | Expr_TApply (FIdent ("select_vec", 0), [ins'; outs'; w'], [x; Expr_LitBits sels']) when is_const_sels sels && w = w' -> + let sels = force_const_sels sels in + let sels' = parse_sels sels' in + let res = List.map (List.nth sels') sels in + let res = List.map expr_of_int res in + push_select ins' w x res st + + (* Acceptable result, consider possible slice reduction *) + | Expr_Var v -> + (match is_slice sels w with + | Some s -> Expr_Slices(Expr_Var v, [s]) + | _ -> select_vec elems w x sels) + (* Evaluate the select given a constant expression *) + | Expr_LitBits x when is_const_sels sels && is_const w -> + let sels = force_const_sels sels in + let w = force_const w in + Expr_LitBits (apply_sels x w sels) + + (* Failure case, wasn't able to reduce *) + | _ -> + Printf.printf "push_select: %s\n" (pp_expr x); + select_vec elems w x sels + + class cleanup st = object + inherit Asl_visitor.nopAslVisitor + method !vexpr e = + (match e with + | Expr_TApply (FIdent ("select_vec", 0), [ins; outs; w], [x;Expr_LitBits sels]) -> + let sels = parse_sels sels in + let sels = List.map (fun i -> expr_of_int i) sels in + ChangeDoChildrenPost(push_select ins w x sels st, fun e -> e) + | _ -> DoChildren) + + (*| Expr_TApply (FIdent ("zcast_vec", 0), tes, [Expr_TApply (FIdent ("add_vec", 0), tes', [ + Expr_TApply (FIdent ("zcast_vec", 0), [n;nw;ow], [x;_;_]) ; + Expr_TApply (FIdent ("zcast_vec", 0), _, [y;_;_]) ; + _ ]) ; _ ; nw']) -> + Expr_TApply (FIdent ("add_vec", 0), [n;nw'], [ + Expr_TApply (FIdent ("zcast_vec", 0), [n;nw';ow], [x;n;nw']) ; + Expr_TApply (FIdent ("zcast_vec", 0), [n;nw';ow], [y;n;nw']) ; + n]) + + | _ -> e) in + ChangeDoChildrenPost(e,fn) *) + end + + let run (s: stmt list) env : (bool * stmt list) = + let tys = Dis_tc.LocalVarTypes.run [] [] s in + let st = { types = tys ; env ; iterations = Expr_LitInt "0" ; ld = Bindings.empty ; vars = Bindings.empty } in + try + let res = walk s tys env in + let res = visit_stmts (new cleanup st) res in + (true,res) + with e -> + let m = Printexc.to_string e in + Printf.printf "\nVec Failure: %s\n" m; + (false,s) + +end diff --git a/libASL/utils.ml b/libASL/utils.ml index cc0917c1..ef16c906 100644 --- a/libASL/utils.ml +++ b/libASL/utils.ml @@ -249,6 +249,11 @@ let pp_list f xs = Printf.sprintf "[%s]" (String.concat " ; " (List.map f xs)) let pp_pair l r (x,y) = Printf.sprintf "(%s, %s)" (l x) (r y) +let pp_option f a = + match a with + | Some v -> "Some(" ^ f v ^ ")" + | _ -> "None" + (**************************************************************** * End ****************************************************************) diff --git a/tests/override.asl b/tests/override.asl index fe39788d..5662de53 100644 --- a/tests/override.asl +++ b/tests/override.asl @@ -326,6 +326,106 @@ integer LowestSetBit(bits(N) x) return N; +// Vector Operations + +bits(W * N) add_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = Elem[x, i, W] + Elem[y, i, W]; + return result; + +bits(W * N) sub_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = Elem[x, i, W] - Elem[y, i, W]; + return result; + +bits(W * N) mul_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = Elem[x, i, W] * Elem[y, i, W]; + return result; + +bits(W * N) lsr_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = lsr_bits(Elem[x, i, W], Elem[y, i, W]); + return result; + +bits(W * N) asr_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = asr_bits(Elem[x, i, W], Elem[y, i, W]); + return result; + +bits(W * N) lsl_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = lsl_bits(Elem[x, i, W], Elem[y, i, W]); + return result; + +bits(W * N) ite_vec(bits(N) c, bits(W * N) x, bits(W * N) y, integer N) + bits(W * N) result; + for i = 0 to (N - 1) + Elem[result, i, W] = if c[i] == '1' then Elem[x, i, W] else Elem[y, i, W]; + return result; + +bits(N) sle_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(N) result; + for i = 0 to (N - 1) + Elem[result, i, 1] = if sle_bits(Elem[x, i, W],Elem[y, i, W]) then '1' else '0'; + return result; + +bits(N) slt_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(N) result; + for i = 0 to (N - 1) + Elem[result, i, 1] = if slt_bits(Elem[x, i, W], Elem[y, i, W]) then '1' else '0'; + return result; + +bits(N) eq_vec(bits(W * N) x, bits(W * N) y, integer N) + bits(N) result; + for i = 0 to (N - 1) + Elem[result, i, 1] = if (Elem[x, i, W] == Elem[y, i, W]) then '1' else '0'; + return result; + +bits(NW * N) zcast_vec(bits(W * N) x, integer N, integer NW) + bits(NW * N) result; + for i = 0 to (N - 1) + Elem[result, i, NW] = ZeroExtend(Elem[x, i, W], NW); + return result; + +bits(NW * N) scast_vec(bits(W * N) x, integer N, integer NW) + bits(NW * N) result; + for i = 0 to (N - 1) + Elem[result, i, NW] = SignExtend(Elem[x, i, W], NW); + return result; + +bits(NW * N) trunc_vec(bits(W * N) x, integer N, integer NW) + bits(NW * N) result; + for i = 0 to (N - 1) + Elem[result, i, NW] = (Elem[x, i, W])[ 0 +: NW ]; + return result; + +bits(W * N) select_vec(bits(W * M) x, bits(32 * N) sel) + bits(W * N) result; + for i = 0 to (N - 1) + integer pos = UInt(Elem[sel,i,32]); + Elem[result, i, W] = Elem[x,pos,W]; + return result; + +bits(W * N) shuffle_vec(bits(W * M) x, bits(W * M) y, bits(32 * N) sel) + bits(W * N) result; + bits(W * M * 2) input = x:y; + for i = 0 to (N - 1) + integer pos = UInt(Elem[sel,i,32]); + Elem[result, i, W] = Elem[input,pos,W]; + return result; + +bits(W) reduce_add(bits(W * N) x, bits(W) init) + bits(W) result = init; + for i = 0 to (N - 1) + result = result + Elem[x,i,W]; + return result; // bits(8*size) _Mem[AddressDescriptor desc, integer size, AccessDescriptor accdesc]; // _Mem[AddressDescriptor desc, integer size, AccessDescriptor accdesc] = bits(8*size) value;