diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs
index 2c692c9500303..7e42cf86ca9d0 100644
--- a/compiler/rustc_mir_transform/src/cost_checker.rs
+++ b/compiler/rustc_mir_transform/src/cost_checker.rs
@@ -2,10 +2,30 @@ use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};

-const INSTR_COST: usize = 5;
-const CALL_PENALTY: usize = 25;
-const LANDINGPAD_PENALTY: usize = 50;
-const RESUME_PENALTY: usize = 45;
+// Even if they're zero-cost at runtime, everything has *some* cost to inline
+// in terms of copying them into the MIR caller, processing them in codegen, etc.
+// These baseline costs give a simple usually-too-low estimate of the cost,
+// which will be updated afterwards to account for the "real" costs.
+const STMT_BASELINE_COST: usize = 1;
+const BLOCK_BASELINE_COST: usize = 3;
+const DEBUG_BASELINE_COST: usize = 1;
+const LOCAL_BASELINE_COST: usize = 1;
+
+// These penalties represent the cost above baseline for those things which
+// have substantially more cost than is typical for their kind.
+const CALL_PENALTY: usize = 22;
+const LANDINGPAD_PENALTY: usize = 47;
+const RESUME_PENALTY: usize = 42;
+const DEREF_PENALTY: usize = 4;
+const CHECKED_OP_PENALTY: usize = 2;
+const THREAD_LOCAL_PENALTY: usize = 20;
+const SMALL_SWITCH_PENALTY: usize = 3;
+const LARGE_SWITCH_PENALTY: usize = 20;
+
+// Passing arguments isn't free, so give a bonus to functions with lots of them:
+// if the body is small despite lots of arguments, some are probably unused.
+const EXTRA_ARG_BONUS: usize = 4;
+const MAX_ARG_BONUS: usize = CALL_PENALTY;

 /// Verify that the callee body is compatible with the caller.
 #[derive(Clone)]
@@ -27,6 +47,20 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
         CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
     }

+    // `Inline` doesn't call `visit_body`, so this is separate from the visitor.
+    pub fn before_body(&mut self, body: &Body<'tcx>) {
+        self.cost += BLOCK_BASELINE_COST * body.basic_blocks.len();
+        self.cost += DEBUG_BASELINE_COST * body.var_debug_info.len();
+        self.cost += LOCAL_BASELINE_COST * body.local_decls.len();
+
+        let total_statements = body.basic_blocks.iter().map(|x| x.statements.len()).sum::<usize>();
+        self.cost += STMT_BASELINE_COST * total_statements;
+
+        if let Some(extra_args) = body.arg_count.checked_sub(2) {
+            self.cost = self.cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS));
+        }
+    }
+
     pub fn cost(&self) -> usize {
         self.cost
     }
@@ -41,14 +75,70 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
 }

 impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
-    fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
-        // Don't count StorageLive/StorageDead in the inlining cost.
-        match statement.kind {
-            StatementKind::StorageLive(_)
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        match &statement.kind {
+            StatementKind::Assign(place_and_rvalue) => {
+                if place_and_rvalue.0.is_indirect_first_projection() {
+                    self.cost += DEREF_PENALTY;
+                }
+                self.visit_rvalue(&place_and_rvalue.1, location);
+            }
+            StatementKind::Intrinsic(intr) => match &**intr {
+                NonDivergingIntrinsic::Assume(..) => {}
+                NonDivergingIntrinsic::CopyNonOverlapping(_cno) => {
+                    self.cost += CALL_PENALTY;
+                }
+            },
+            StatementKind::FakeRead(..)
+            | StatementKind::SetDiscriminant { .. }
+            | StatementKind::StorageLive(_)
             | StatementKind::StorageDead(_)
+            | StatementKind::Retag(..)
+            | StatementKind::PlaceMention(..)
+            | StatementKind::AscribeUserType(..)
+            | StatementKind::Coverage(..)
             | StatementKind::Deinit(_)
-            | StatementKind::Nop => {}
-            _ => self.cost += INSTR_COST,
+            | StatementKind::ConstEvalCounter
+            | StatementKind::Nop => {
+                // No extra cost for these
+            }
+        }
+    }
+
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
+        match rvalue {
+            Rvalue::Use(operand) => {
+                if let Some(place) = operand.place()
+                    && place.is_indirect_first_projection()
+                {
+                    self.cost += DEREF_PENALTY;
+                }
+            }
+            Rvalue::Repeat(_item, count) => {
+                let count = count.try_to_target_usize(self.tcx).unwrap_or(u64::MAX);
+                self.cost += (STMT_BASELINE_COST * count as usize).min(CALL_PENALTY);
+            }
+            Rvalue::Aggregate(_kind, fields) => {
+                self.cost += STMT_BASELINE_COST * fields.len();
+            }
+            Rvalue::CheckedBinaryOp(..) => {
+                self.cost += CHECKED_OP_PENALTY;
+            }
+            Rvalue::ThreadLocalRef(..) => {
+                self.cost += THREAD_LOCAL_PENALTY;
+            }
+            Rvalue::Ref(..)
+            | Rvalue::AddressOf(..)
+            | Rvalue::Len(..)
+            | Rvalue::Cast(..)
+            | Rvalue::BinaryOp(..)
+            | Rvalue::NullaryOp(..)
+            | Rvalue::UnaryOp(..)
+            | Rvalue::Discriminant(..)
+            | Rvalue::ShallowInitBox(..)
+            | Rvalue::CopyForDeref(..) => {
+                // No extra cost for these
+            }
         }
     }

@@ -63,24 +153,35 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
                     if let UnwindAction::Cleanup(_) = unwind {
                         self.cost += LANDINGPAD_PENALTY;
                     }
-                } else {
-                    self.cost += INSTR_COST;
                 }
             }
-            TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
-                let fn_ty = self.instantiate_ty(f.const_.ty());
-                self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind()
+            TerminatorKind::Call { ref func, unwind, .. } => {
+                if let Some(f) = func.constant()
+                    && let fn_ty = self.instantiate_ty(f.ty())
+                    && let ty::FnDef(def_id, _) = *fn_ty.kind()
                     && tcx.intrinsic(def_id).is_some()
                 {
                     // Don't give intrinsics the extra penalty for calls
-                    INSTR_COST
                 } else {
-                    CALL_PENALTY
+                    self.cost += CALL_PENALTY;
                 };
                 if let UnwindAction::Cleanup(_) = unwind {
                     self.cost += LANDINGPAD_PENALTY;
                 }
             }
+            TerminatorKind::SwitchInt { ref discr, ref targets } => {
+                if let Operand::Constant(..) = discr {
+                    // This'll be a goto once we're monomorphizing
+                } else {
+                    // 0/1/unreachable is extremely common (bool, Option, Result, ...)
+                    // but once there's more this can be a fair bit of work.
+                    self.cost += if targets.all_targets().len() <= 3 {
+                        SMALL_SWITCH_PENALTY
+                    } else {
+                        LARGE_SWITCH_PENALTY
+                    };
+                }
+            }
             TerminatorKind::Assert { unwind, .. } => {
                 self.cost += CALL_PENALTY;
                 if let UnwindAction::Cleanup(_) = unwind {
@@ -89,12 +190,20 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
             }
             TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
             TerminatorKind::InlineAsm { unwind, .. } => {
-                self.cost += INSTR_COST;
                 if let UnwindAction::Cleanup(_) = unwind {
                     self.cost += LANDINGPAD_PENALTY;
                 }
             }
-            _ => self.cost += INSTR_COST,
+            TerminatorKind::Goto { .. }
+            | TerminatorKind::UnwindTerminate(..)
+            | TerminatorKind::Return
+            | TerminatorKind::Yield { .. }
+            | TerminatorKind::CoroutineDrop
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. }
+            | TerminatorKind::Unreachable => {
+                // No extra cost for these
+            }
         }
     }
 }
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 5f74841151cda..ea75a72d21729 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -506,6 +506,17 @@ impl<'tcx> Inliner<'tcx> {
         let mut checker =
             CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body);

+        checker.before_body(callee_body);
+
+        let baseline_cost = checker.cost();
+        if baseline_cost > threshold {
+            debug!(
+                "NOT inlining {:?} [baseline_cost={} > threshold={}]",
+                callsite, baseline_cost, threshold
+            );
+            return Err("baseline_cost above threshold");
+        }
+
         // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
         let mut work_list = vec![START_BLOCK];
         let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());
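
For intuition, here is a standalone sketch (illustration only, not part of the patch) that replays the arithmetic `before_body` performs, i.e. the `baseline_cost` that the new early return in inline.rs compares against the threshold. The constants are copied from the first hunk; the body sizes are hypothetical:

const STMT_BASELINE_COST: usize = 1;
const BLOCK_BASELINE_COST: usize = 3;
const DEBUG_BASELINE_COST: usize = 1;
const LOCAL_BASELINE_COST: usize = 1;
const CALL_PENALTY: usize = 22;
const EXTRA_ARG_BONUS: usize = 4;
const MAX_ARG_BONUS: usize = CALL_PENALTY;

fn main() {
    // Hypothetical callee: 4 blocks, 10 statements, 6 locals, 2 debug vars, 4 args.
    let (blocks, statements, locals, debug_vars) = (4, 10, 6, 2);
    let arg_count: usize = 4;

    let mut cost = BLOCK_BASELINE_COST * blocks
        + STMT_BASELINE_COST * statements
        + LOCAL_BASELINE_COST * locals
        + DEBUG_BASELINE_COST * debug_vars;
    assert_eq!(cost, 30); // 3*4 + 1*10 + 1*6 + 1*2

    // Everything past the second argument earns a bonus, capped at one call's
    // worth, mirroring the `checked_sub`/`saturating_sub` steps in `before_body`.
    if let Some(extra_args) = arg_count.checked_sub(2) {
        cost = cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS));
    }
    assert_eq!(cost, 22); // 30 - min(4*2, 22)

    // The visitor then layers penalties on top of this baseline; e.g. one
    // non-intrinsic call would bring this body to 22 + CALL_PENALTY = 44.
    assert_eq!(cost + CALL_PENALTY, 44);
}

Note how the cap keeps the bonus honest: two extra arguments shave 8 off the baseline here, but no argument count can ever recoup more than a single CALL_PENALTY.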