diff --git a/middle_end/flambda/simplify/simplify_expr.ml b/middle_end/flambda/simplify/simplify_expr.ml index 34c9c4fc7526..5613172b1aa7 100644 --- a/middle_end/flambda/simplify/simplify_expr.ml +++ b/middle_end/flambda/simplify/simplify_expr.ml @@ -37,6 +37,7 @@ let rec simplify_expr dacc expr ~down_to_up = Simplify_apply_cont_expr.simplify_apply_cont dacc apply_cont ~down_to_up | Switch switch -> Simplify_switch_expr.simplify_switch ~simplify_let dacc switch ~down_to_up + ~original_expr:expr | Invalid _ -> (* CR mshinwell: Make sure that a program can be simplified to just [Invalid]. [Un_cps] should translate any [Invalid] that it sees as if diff --git a/middle_end/flambda/simplify/simplify_switch_expr.ml b/middle_end/flambda/simplify/simplify_switch_expr.ml index 05b9135e5715..2b25c6952119 100644 --- a/middle_end/flambda/simplify/simplify_switch_expr.ml +++ b/middle_end/flambda/simplify/simplify_switch_expr.ml @@ -18,8 +18,8 @@ open! Simplify_import -let rebuild_switch ~simplify_let dacc ~arms ~scrutinee ~scrutinee_ty uacc - ~after_rebuild = +let rebuild_switch ~arms ~scrutinee ~scrutinee_ty + ~tagged_scrutinee ~not_scrutinee uacc ~after_rebuild = let new_let_conts, arms, identity_arms, not_arms = Target_imm.Map.fold (fun arm (action, use_id, arity) @@ -141,38 +141,6 @@ let rebuild_switch ~simplify_let dacc ~arms ~scrutinee ~scrutinee_ty uacc |> Continuation.Set.of_list |> Continuation.Set.get_singleton in - let create_tagged_scrutinee uacc dest ~make_body = - (* A problem with using [simplify_let] below is that the continuation - [dest] might have [Apply_cont_rewrite]s in the environment, left over - from the simplification of the existing uses. We must clear these to - avoid a lookup failure for our new [Apply_cont] when - [Simplify_apply_cont] tries to rewrite the use. There is no need for - the rewrites anyway; they have already been applied. - Likewise, we need to clear the continuation uses environment for - [dest] in [dacc], since our new [Apply_cont] might not match the - original uses (e.g. if a parameter has been removed). *) - let uacc = - UA.map_uenv uacc ~f:(fun uenv -> - UE.delete_apply_cont_rewrite uenv dest) - in - let dacc = DA.delete_continuation_uses dacc dest in - let bound_to = Variable.create "tagged_scrutinee" in - let body = make_body ~tagged_scrutinee:(Simple.var bound_to) in - let bound_to = Var_in_binding_pos.create bound_to NM.normal in - let defining_expr = - Named.create_prim (Unary (Box_number Untagged_immediate, scrutinee)) - Debuginfo.none - in - let let_expr = - Let.create (Bindable_let_bound.singleton bound_to) - defining_expr - ~body - ~free_names_of_body:Unknown - in - simplify_let dacc let_expr - ~down_to_up:(fun _dacc ~rebuild -> - rebuild uacc ~after_rebuild:(fun expr uacc -> expr, uacc)) - in (* CR mshinwell: Here and elsewhere [UA.name_occurrences] should be empty (maybe except for closure vars? -- check). We should add asserts. *) let body, uacc = @@ -185,37 +153,29 @@ let rebuild_switch ~simplify_let dacc ~arms ~scrutinee ~scrutinee_ty uacc let dbg = Debuginfo.none in match switch_is_identity with | Some dest -> + let apply_cont = Apply_cont.create dest ~args:[tagged_scrutinee] ~dbg in + let uacc = + UA.map_uenv uacc ~f:(fun uenv -> + UE.delete_apply_cont_rewrite uenv dest) + in let uacc = - UA.notify_removed ~operation:Removed_operations.branch uacc + UA.add_free_names uacc (Apply_cont.free_names apply_cont) |> + UA.notify_removed ~operation:Removed_operations.branch in - create_tagged_scrutinee uacc dest ~make_body:(fun ~tagged_scrutinee -> - (* No need to increment the cost_metrics inside [create_tagged_scrutinee] as it - will call simplify over the result of [make_body]. *) - Apply_cont.create dest ~args:[tagged_scrutinee] ~dbg - |> Expr.create_apply_cont) + Rebuilt_expr.create_apply_cont apply_cont, uacc | None -> match switch_is_boolean_not with | Some dest -> + let apply_cont = Apply_cont.create dest ~args:[not_scrutinee] ~dbg in let uacc = - UA.notify_removed ~operation:Removed_operations.branch uacc + UA.map_uenv uacc ~f:(fun uenv -> + UE.delete_apply_cont_rewrite uenv dest) in - create_tagged_scrutinee uacc dest ~make_body:(fun ~tagged_scrutinee -> - let not_scrutinee = Variable.create "not_scrutinee" in - let not_scrutinee' = Simple.var not_scrutinee in - let do_tagging = - Named.create_prim (P.Unary (Boolean_not, tagged_scrutinee)) - Debuginfo.none - in - let bound = - VB.create not_scrutinee NM.normal - |> Bindable_let_bound.singleton - in - let body = - Apply_cont.create dest ~args:[not_scrutinee'] ~dbg - |> Expr.create_apply_cont - in - Let.create bound do_tagging ~body ~free_names_of_body:Unknown - |> Expr.create_let) + let uacc = + UA.add_free_names uacc (Apply_cont.free_names apply_cont) |> + UA.notify_removed ~operation:Removed_operations.branch + in + Rebuilt_expr.create_apply_cont apply_cont, uacc | None -> (* In that case, even though some branches were removed by simplify we should not count them in the number of removed operations: these @@ -258,13 +218,10 @@ let rebuild_switch ~simplify_let dacc ~arms ~scrutinee ~scrutinee_ty uacc in after_rebuild expr uacc -let simplify_switch ~simplify_let dacc switch ~down_to_up = +let simplify_switch_aux ~scrutinee ~scrutinee_ty + ~tagged_scrutinee ~not_scrutinee + dacc switch ~down_to_up = let module AC = Apply_cont in - let scrutinee = Switch.scrutinee switch in - let scrutinee_ty = - S.simplify_simple dacc scrutinee ~min_name_mode:NM.normal - in - let scrutinee = T.get_alias_exn scrutinee_ty in let arms, dacc = let typing_env_at_use = DA.typing_env dacc in Target_imm.Map.fold (fun arm action (arms, dacc) -> @@ -306,5 +263,77 @@ let simplify_switch ~simplify_let dacc switch ~down_to_up = (Target_imm.Map.empty, dacc) in down_to_up dacc - ~rebuild:(rebuild_switch ~simplify_let dacc ~arms ~scrutinee - ~scrutinee_ty) + ~rebuild:(rebuild_switch ~arms ~scrutinee + ~scrutinee_ty ~tagged_scrutinee ~not_scrutinee) + +let simplify_switch ~simplify_let ~original_expr dacc switch ~down_to_up = + let scrutinee = Switch.scrutinee switch in + let scrutinee_ty = + S.simplify_simple dacc scrutinee ~min_name_mode:NM.normal + in + let scrutinee = T.get_alias_exn scrutinee_ty in + let find_cse_simple prim = + (* prim is either boolean not or tagging of non constant values *) + let with_fixed_value = P.Eligible_for_cse.create_exn prim in + match DE.find_cse (DA.denv dacc) with_fixed_value with + | None -> None + | Some simple -> + match + TE.get_canonical_simple_exn (DA.typing_env dacc) simple + ~min_name_mode:NM.normal + ~name_mode_of_existing_simple:NM.normal + with + | exception Not_found -> None + | simple -> Some simple + in + let create_def name prim = + let bound_to = Variable.create name in + let bound_to = Var_in_binding_pos.create bound_to NM.normal in + let defining_expr = Named.create_prim prim Debuginfo.none in + let let_expr = + Let.create (Bindable_let_bound.singleton bound_to) + defining_expr + ~body:original_expr + ~free_names_of_body:Unknown + in + simplify_let dacc let_expr ~down_to_up + in + let tag_prim = P.Unary (Box_number Untagged_immediate, scrutinee) in + Simple.pattern_match scrutinee + ~const:(fun const -> + match Reg_width_things.Const.descr const with + | Naked_immediate imm -> + let tagged_scrutinee = + Simple.const (Reg_width_things.Const.tagged_immediate imm) + in + let not_scrutinee = + let not_imm = + if Target_imm.equal imm Target_imm.zero then + Target_imm.one + else + (* If the scrutinee is neither zero nor one, this value + won't be used *) + Target_imm.zero + in + Simple.const (Reg_width_things.Const.tagged_immediate not_imm) + in + simplify_switch_aux dacc switch ~down_to_up + ~tagged_scrutinee ~not_scrutinee + ~scrutinee ~scrutinee_ty + | Tagged_immediate _ | Naked_float _ | Naked_int32 _ + | Naked_int64 _ | Naked_nativeint _ -> + Misc.fatal_errorf "Switch scrutinee is not a naked immediate: %a" + Simple.print scrutinee) + ~name:(fun _ -> + match find_cse_simple tag_prim with + | None -> + create_def "tagged_scrutinee" tag_prim + | Some tagged_scrutinee -> + let not_prim = P.Unary (Boolean_not, tagged_scrutinee) in + match find_cse_simple not_prim with + | None -> + create_def "not_scrutinee" not_prim + | Some not_scrutinee -> + simplify_switch_aux dacc switch ~down_to_up + ~tagged_scrutinee ~not_scrutinee + ~scrutinee ~scrutinee_ty) diff --git a/middle_end/flambda/simplify/simplify_switch_expr.mli b/middle_end/flambda/simplify/simplify_switch_expr.mli index 6d623783953a..c9af3559f259 100644 --- a/middle_end/flambda/simplify/simplify_switch_expr.mli +++ b/middle_end/flambda/simplify/simplify_switch_expr.mli @@ -18,4 +18,5 @@ val simplify_switch : simplify_let:Flambda.Let.t Simplify_common.expr_simplifier + -> original_expr:Flambda.Expr.t -> Flambda.Switch.t Simplify_common.expr_simplifier