From 01177094d660286ebe8c83c68d1e8890e29d154c Mon Sep 17 00:00:00 2001 From: Alasdair Armstrong Date: Thu, 11 Jul 2019 14:21:25 +0100 Subject: [PATCH] Make sure constant folding won't fold external definitions that also have sail definitions Definitions can be made external on a per-backend basis, so we need to make sure constant folding doesn't inline external functions that have sail definitions for backends other than the ones we are currently targetting --- src/constant_fold.ml | 20 ++++++++++++-------- src/constant_propagation.ml | 12 ++++++------ src/constant_propagation.mli | 1 + src/constant_propagation_mutrec.ml | 13 +++++++------ src/monomorphise.ml | 8 ++++---- src/monomorphise.mli | 1 + src/rewrites.ml | 17 +++++++++-------- 7 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/constant_fold.ml b/src/constant_fold.ml index abedcf359..7e21294b8 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -191,7 +191,7 @@ let rec run frame = let initial_state ast env = Interpreter.initial_state ~registers:false ast env safe_primops -let rw_exp ok not_ok istate = +let rw_exp target ok not_ok istate = let evaluate e_aux annot = let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in try @@ -220,7 +220,11 @@ let rw_exp ok not_ok istate = ok (); E_aux (E_lit (L_aux (L_unit, fst annot)), annot) | E_app (id, args) when List.for_all is_constant args -> - evaluate e_aux annot + let env = env_of_annot annot in + if not (Env.is_extern id env target) then + evaluate e_aux annot + else + E_aux (e_aux, annot) | E_cast (typ, (E_aux (E_lit _, _) as lit)) -> ok (); lit @@ -243,9 +247,9 @@ let rw_exp ok not_ok istate = in fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)} -let rewrite_exp_once = rw_exp (fun _ -> ()) (fun _ -> ()) +let rewrite_exp_once target = rw_exp target (fun _ -> ()) (fun _ -> ()) -let rec rewrite_constant_function_calls' ast = +let rec rewrite_constant_function_calls' target ast = let rewrite_count = ref 0 in let ok () = incr rewrite_count in let not_ok () = decr rewrite_count in @@ -253,16 +257,16 @@ let rec rewrite_constant_function_calls' ast = let rw_defs = { rewriters_base with - rewrite_exp = (fun _ -> rw_exp ok not_ok istate) + rewrite_exp = (fun _ -> rw_exp target ok not_ok istate) } in let ast = rewrite_defs_base rw_defs ast in (* We keep iterating until we have no more re-writes to do *) if !rewrite_count > 0 - then rewrite_constant_function_calls' ast + then rewrite_constant_function_calls' target ast else ast -let rewrite_constant_function_calls ast = +let rewrite_constant_function_calls target ast = if !optimize_constant_fold then - rewrite_constant_function_calls' ast + rewrite_constant_function_calls' target ast else ast diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml index 201e43e7c..00b3d1920 100644 --- a/src/constant_propagation.ml +++ b/src/constant_propagation.ml @@ -301,7 +301,7 @@ let is_env_inconsistent env ksubsts = module StringSet = Set.Make(String) module StringMap = Map.Make(String) -let const_props defs ref_vars = +let const_props target defs ref_vars = let const_fold exp = (* Constant-fold function applications with constant arguments *) let interpreter_istate = @@ -316,7 +316,7 @@ let const_props defs ref_vars = try strip_exp exp |> infer_exp (env_of exp) - |> Constant_fold.rewrite_exp_once interpreter_istate + |> Constant_fold.rewrite_exp_once target interpreter_istate |> keep_undef_typ with | _ -> exp @@ -603,7 +603,7 @@ let const_props defs ref_vars = | E_assert (e1,e2) -> let e1',e2',assigns = non_det_exp_2 e1 e2 in re (E_assert (e1',e2')) assigns - + | E_app_infix _ | E_var _ | E_internal_plet _ @@ -803,15 +803,15 @@ let const_props defs ref_vars = | DoesMatch (subst,ksubst) -> Some (exp,subst,ksubst) | GiveUp -> None in findpat_generic (string_of_exp exp0) assigns cases - + and can_match exp = let env = Type_check.env_of exp in can_match_with_env env exp in (const_prop_exp, const_prop_pexp) -let const_prop d r = fst (const_props d r) -let const_prop_pexp d r = snd (const_props d r) +let const_prop target d r = fst (const_props target d r) +let const_prop_pexp target d r = snd (const_props target d r) let referenced_vars exp = let open Rewriter in diff --git a/src/constant_propagation.mli b/src/constant_propagation.mli index 437492c66..9c182cb09 100644 --- a/src/constant_propagation.mli +++ b/src/constant_propagation.mli @@ -59,6 +59,7 @@ open Type_check (and hence we cannot reliably track). *) val const_prop : + string -> tannot defs -> IdSet.t -> tannot exp Bindings.t * nexp KBindings.t -> diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml index 285ba45d2..6cc6d28c8 100644 --- a/src/constant_propagation_mutrec.ml +++ b/src/constant_propagation_mutrec.ml @@ -130,7 +130,7 @@ let generate_val_spec env id args l annot = | _, Typ_aux (_, l) -> raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type") -let const_prop defs substs ksubsts exp = +let const_prop target defs substs ksubsts exp = (* Constant_propagation currently only supports nexps for kid substitutions *) let nexp_substs = KBindings.bindings ksubsts @@ -139,6 +139,7 @@ let const_prop defs substs ksubsts exp = |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty in Constant_propagation.const_prop + target (Defs defs) (Constant_propagation.referenced_vars exp) (substs, nexp_substs) @@ -147,7 +148,7 @@ let const_prop defs substs ksubsts exp = |> fst (* Propagate constant arguments into function clause pexp *) -let prop_args_pexp defs ksubsts args pexp = +let prop_args_pexp target defs ksubsts args pexp = let pat, guard, exp, annot = destruct_pexp pexp in let pats = match pat with | P_aux (P_tup pats, _) -> pats @@ -164,14 +165,14 @@ let prop_args_pexp defs ksubsts args pexp = else (pat :: pats, substs) in let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in - let exp' = const_prop defs substs ksubsts exp in + let exp' = const_prop target defs substs ksubsts exp in let pat' = match pats with | [pat] -> pat | _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot)) in construct_pexp (pat', guard, exp', annot) -let rewrite_defs env (Defs defs) = +let rewrite_defs target env (Defs defs) = let rec rewrite = function | [] -> [] | DEF_internal_mutrec mutrecs :: ds -> @@ -194,7 +195,7 @@ let rewrite_defs env (Defs defs) = let valspec, ksubsts = generate_val_spec env id args l annot in let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) = let pexp' = - prop_args_pexp defs ksubsts args pexp + prop_args_pexp target defs ksubsts args pexp |> rewrite_pexp |> strip_pexp in @@ -215,7 +216,7 @@ let rewrite_defs env (Defs defs) = let pexp' = if List.exists (fun id' -> Id.compare id id' = 0) !targets then let pat, guard, body, annot = destruct_pexp pexp in - let body' = const_prop defs Bindings.empty KBindings.empty body in + let body' = const_prop target defs Bindings.empty KBindings.empty body in rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot)) else pexp in FCL_aux (FCL_Funcl (id, pexp'), a) diff --git a/src/monomorphise.ml b/src/monomorphise.ml index b8b3b9351..7a43ca6cc 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -620,7 +620,7 @@ let apply_pat_choices choices = e_assert = rewrite_assert; e_case = rewrite_case } -let split_defs all_errors splits env defs = +let split_defs target all_errors splits env defs = let no_errors_happened = ref true in let error_opt = if all_errors then Some no_errors_happened else None in let split_constructors (Defs defs) = @@ -651,7 +651,7 @@ let split_defs all_errors splits env defs = let subst_exp ref_vars substs ksubsts exp = let substs = bindings_from_list substs, ksubsts in - fst (Constant_propagation.const_prop defs ref_vars substs Bindings.empty exp) + fst (Constant_propagation.const_prop target defs ref_vars substs Bindings.empty exp) in (* Split a variable pattern into every possible value *) @@ -3789,7 +3789,7 @@ let recheck defs = let mono_rewrites = MonoRewrites.mono_rewrite -let monomorphise opts splits defs = +let monomorphise target opts splits defs = let defs, env = Type_check.check Type_check.initial_env defs in let ok_analysis, new_splits, extra_splits = if opts.auto @@ -3806,7 +3806,7 @@ let monomorphise opts splits defs = then () else raise (Reporting.err_general Unknown "Unable to monomorphise program") in - let ok_split, defs = split_defs opts.all_split_errors splits env defs in + let ok_split, defs = split_defs target opts.all_split_errors splits env defs in let () = if (ok_analysis && ok_extras && ok_split) || opts.continue_anyway then () else raise (Reporting.err_general Unknown "Unable to monomorphise program") diff --git a/src/monomorphise.mli b/src/monomorphise.mli index 1a82c8d0d..39d89461d 100644 --- a/src/monomorphise.mli +++ b/src/monomorphise.mli @@ -56,6 +56,7 @@ type options = { } val monomorphise : + string -> (* Target backend *) options -> ((string * int) * string) list -> (* List of splits from the command line *) Type_check.tannot Ast.defs -> diff --git a/src/rewrites.ml b/src/rewrites.ml index becf2a88d..d396e18bd 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4772,9 +4772,10 @@ let opt_auto_mono = ref false let opt_dall_split_errors = ref false let opt_dmono_continue = ref false -let monomorphise env defs = +let monomorphise target env defs = let open Monomorphise in monomorphise + target { auto = !opt_auto_mono; debug_analysis = !opt_dmono_analysis; all_split_errors = !opt_dall_split_errors; @@ -4850,12 +4851,12 @@ let all_rewrites = [ ("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns); ("mono_rewrites", Basic_rewriter mono_rewrites); ("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps); - ("monomorphise", Basic_rewriter monomorphise); + ("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target))); ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts)); ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); ("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases); - ("const_prop_mutrec", Basic_rewriter Constant_propagation_mutrec.rewrite_defs); + ("const_prop_mutrec", String_rewriter (fun target -> Basic_rewriter (Constant_propagation_mutrec.rewrite_defs target))); ("make_cases_exhaustive", Basic_rewriter MakeExhaustive.rewrite); ("undefined", Bool_rewriter (fun b -> Basic_rewriter (rewrite_undefined_if_gen b))); ("vector_string_pats_to_bit_list", Basic_rewriter rewrite_defs_vector_string_pats_to_bit_list); @@ -4887,7 +4888,7 @@ let all_rewrites = [ ("simple_types", Basic_rewriter rewrite_simple_types); ("overload_cast", Basic_rewriter rewrite_overload_cast); ("top_sort_defs", Basic_rewriter (fun _ -> top_sort_defs)); - ("constant_fold", Basic_rewriter (fun _ -> Constant_fold.rewrite_constant_function_calls)); + ("constant_fold", String_rewriter (fun target -> Basic_rewriter (fun _ -> Constant_fold.rewrite_constant_function_calls target))); ("split", String_rewriter (fun str -> Basic_rewriter (rewrite_split_fun_ctor_pats str))); ("properties", Basic_rewriter (fun _ -> Property.rewrite)); ] @@ -4902,7 +4903,7 @@ let rewrites_lem = [ ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); ("toplevel_nexps", [If_mono_arg]); - ("monomorphise", [If_mono_arg]); + ("monomorphise", [String_arg "lem"; If_mono_arg]); ("recheck_defs", [If_mwords_arg]); ("add_bitvector_casts", [If_mwords_arg]); ("atoms_to_singletons", [If_mono_arg]); @@ -4925,7 +4926,7 @@ let rewrites_lem = [ ("split", [String_arg "execute"]); ("recheck_defs", []); ("top_sort_defs", []); - ("const_prop_mutrec", []); + ("const_prop_mutrec", [String_arg "lem"]); ("vector_string_pats_to_bit_list", []); ("exp_lift_assign", []); ("early_return", []); @@ -5021,7 +5022,7 @@ let rewrites_c = [ ("mono_rewrites", [If_mono_arg]); ("recheck_defs", [If_mono_arg]); ("toplevel_nexps", [If_mono_arg]); - ("monomorphise", [If_mono_arg]); + ("monomorphise", [String_arg "c"; If_mono_arg]); ("atoms_to_singletons", [If_mono_arg]); ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); @@ -5036,7 +5037,7 @@ let rewrites_c = [ ("exp_lift_assign", []); ("merge_function_clauses", []); ("optimize_recheck_defs", []); - ("constant_fold", []) + ("constant_fold", [String_arg "c"]) ] let rewrites_interpreter = [