Skip to content

Commit

Permalink
upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Jan 2, 2025
1 parent c528b8c commit 2ad340e
Show file tree
Hide file tree
Showing 25 changed files with 444 additions and 26 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4233,6 +4233,7 @@ name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_abi",
"rustc_ast",
"rustc_attr_parsing",
"rustc_data_structures",
"rustc_errors",
Expand All @@ -4242,6 +4243,7 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
}

let diag_handler = cgcx.create_dcx();
let (_, cgus) = tcx.collect_and_partition_mono_items(());
let (_, _, cgus) = tcx.collect_and_partition_mono_items(());
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);

// Before dumping the module, we want all the TypeTrees to become part of the module.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ struct UsageSets<'tcx> {
/// Prepare sets of definitions that are relevant to deciding whether something
/// is an "unused function" for coverage purposes.
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
let (all_mono_items, cgus) = tcx.collect_and_partition_mono_items(());
let (all_mono_items, _, cgus) = tcx.collect_and_partition_mono_items(());

// Obtain a MIR body for each function participating in codegen, via an
// arbitrary instance.
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_ssa/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,6 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp
codegen_ssa_version_script_write_failure = failed to write version script: {$error}
codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/assert_module_sources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr
}

let available_cgus =
tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect();
tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect();

let mut ams = AssertModuleSource {
tcx,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/back/symbol_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ fn exported_symbols_provider_local(
// external linkage is enough for monomorphization to be linked to.
let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib;

let (_, cgus) = tcx.collect_and_partition_mono_items(());
let (_, _, cgus) = tcx.collect_and_partition_mono_items(());

// The symbols created in this loop are sorted below it
#[allow(rustc::potential_query_instability)]
Expand Down
37 changes: 33 additions & 4 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
use std::{fs, io, mem, str, thread};

use rustc_ast::attr;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
use rustc_data_structures::jobserver::{self, Acquired};
use rustc_data_structures::memmap::Mmap;
Expand Down Expand Up @@ -40,7 +41,7 @@ use tracing::debug;
use super::link::{self, ensure_removed};
use super::lto::{self, SerializedModule};
use super::symbol_export::symbol_name_for_instance_in_crate;
use crate::errors::ErrorCreatingRemarkDir;
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
use crate::traits::*;
use crate::{
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
Expand Down Expand Up @@ -118,6 +119,7 @@ pub struct ModuleConfig {
pub merge_functions: bool,
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -266,6 +268,7 @@ impl ModuleConfig {

emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
}
}

Expand Down Expand Up @@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {

fn generate_lto_work<B: ExtraBackendMethods>(
cgcx: &CodegenContext<B>,
autodiff: Vec<AutoDiffItem>,
needs_fat_lto: Vec<FatLtoInput<B>>,
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
Expand All @@ -399,9 +403,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
assert!(needs_thin_lto.is_empty());
let module =
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
if cgcx.lto == Lto::Fat {
let _config = cgcx.config(ModuleKind::Regular);
todo!("fat LTO with autodiff is not yet implemented");
//module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
}
// We are adding a single work item, so the cost doesn't matter.
vec![(WorkItem::LTO(module), 0)]
} else {
if !autodiff.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
assert!(needs_fat_lto.is_empty());
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
.unwrap_or_else(|e| e.raise());
Expand Down Expand Up @@ -1021,6 +1034,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
/// Sent from a backend worker thread.
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },

/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
AddAutoDiffItems(Vec<AutoDiffItem>),

/// The frontend has finished generating something (backend IR or a
/// post-LTO artifact) for a codegen unit, and it should be passed to the
/// backend. Sent from the main thread.
Expand Down Expand Up @@ -1348,6 +1364,7 @@ fn start_executing_work<B: ExtraBackendMethods>(

// This is where we collect codegen units that have gone all the way
// through codegen and LLVM.
let mut autodiff_items = Vec::new();
let mut compiled_modules = vec![];
let mut compiled_allocator_module = None;
let mut needs_link = Vec::new();
Expand Down Expand Up @@ -1459,9 +1476,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
let needs_thin_lto = mem::take(&mut needs_thin_lto);
let import_only_modules = mem::take(&mut lto_import_only_modules);

for (work, cost) in
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
{
for (work, cost) in generate_lto_work(
&cgcx,
autodiff_items.clone(),
needs_fat_lto,
needs_thin_lto,
import_only_modules,
) {
let insertion_index = work_items
.binary_search_by_key(&cost, |&(_, cost)| cost)
.unwrap_or_else(|e| e);
Expand Down Expand Up @@ -1596,6 +1617,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
main_thread_state = MainThreadState::Idle;
}

Message::AddAutoDiffItems(mut items) => {
autodiff_items.append(&mut items);
}

Message::CodegenComplete => {
if codegen_state != Aborted {
codegen_state = Completed;
Expand Down Expand Up @@ -2070,6 +2095,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
}

/// Hands the given autodiff items to the coordinator by sending it a
/// `Message::AddAutoDiffItems`. The result of the send is deliberately discarded.
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
    let message = Box::new(Message::<B>::AddAutoDiffItems(items));
    drop(self.coordinator.sender.send(message));
}

pub(crate) fn check_for_errors(&self, sess: &Session) {
self.shared_emitter_main.check(sess, false);
}
Expand Down
9 changes: 7 additions & 2 deletions compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,8 @@ pub fn codegen_crate<B: ExtraBackendMethods>(

// Run the monomorphization collector and partition the collected items into
// codegen units.
let codegen_units = tcx.collect_and_partition_mono_items(()).1;
let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(());
let autodiff_fncs = autodiff_fncs.to_vec();

// Force all codegen_unit queries so they are already either red or green
// when compile_codegen_unit accesses them. We are not able to re-execute
Expand Down Expand Up @@ -691,6 +692,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
);
}

if !autodiff_fncs.is_empty() {
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
}

// For better throughput during parallel processing by LLVM, we used to sort
// CGUs largest to smallest. This would lead to better thread utilization
// by, for example, preventing a large CGU from being processed last and
Expand Down Expand Up @@ -1051,7 +1056,7 @@ pub(crate) fn provide(providers: &mut Providers) {
config::OptLevel::SizeMin => config::OptLevel::Default,
};

let (defids, _) = tcx.collect_and_partition_mono_items(cratenum);
let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum);

let any_for_speed = defids.items().any(|id| {
let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id);
Expand Down
136 changes: 134 additions & 2 deletions compiler/rustc_codegen_ssa/src/codegen_attrs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::str::FromStr;

use rustc_ast::attr::list_contains_name;
use rustc_ast::{MetaItemInner, attr};
use rustc_ast::expand::autodiff_attrs::{
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
};
use rustc_ast::{MetaItem, MetaItemInner, attr};
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::codes::*;
Expand Down Expand Up @@ -828,6 +833,133 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
}
}

/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);

let attrs =
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();

// check for exactly one autodiff attribute on placeholder functions.
// There should only be one, since we generate a new placeholder per ad macro.
// TODO: re-enable this. We should fix that rustc_autodiff isn't applied multiple times to the
// source function.
let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once";
let attr = match attrs.len() {
0 => return AutoDiffAttrs::error(),
1 => attrs.get(0).unwrap(),
_ => {
attrs.get(0).unwrap()
//tcx.dcx().struct_span_err(attrs[1].span, msg_once).with_note("more than one").emit();
//return AutoDiffAttrs::error();
}
};

let list = attr.meta_item_list().unwrap_or_default();

// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
if list.len() == 0 {
return AutoDiffAttrs::source();
}

let [mode, input_activities @ .., ret_activity] = &list[..] else {
tcx.dcx()
.struct_span_err(attr.span, msg_once)
.with_note("Implementation bug in autodiff_attrs. Please report this!")
.emit();
return AutoDiffAttrs::error();
};
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
p1.segments.first().unwrap().ident
} else {
let msg = "autodiff attribute must contain autodiff mode";
tcx.dcx().struct_span_err(attr.span, msg).with_note("empty argument list").emit();
return AutoDiffAttrs::error();
};

// parse mode
let msg_mode = "mode should be either forward or reverse";
let mode = match mode.as_str() {
"Forward" => DiffMode::Forward,
"Reverse" => DiffMode::Reverse,
"ForwardFirst" => DiffMode::ForwardFirst,
"ReverseFirst" => DiffMode::ReverseFirst,
_ => {
tcx.dcx().struct_span_err(attr.span, msg_mode).with_note("invalid mode").emit();
return AutoDiffAttrs::error();
}
};

// First read the ret symbol from the attribute
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
p1.segments.first().unwrap().ident
} else {
let msg = "autodiff attribute must contain the return activity";
tcx.dcx().struct_span_err(attr.span, msg).with_note("missing return activity").emit();
return AutoDiffAttrs::error();
};

// Then parse it into an actual DiffActivity
let msg_unknown_ret_activity = "unknown return activity";
let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) {
Ok(x) => x,
Err(_) => {
tcx.dcx()
.struct_span_err(attr.span, msg_unknown_ret_activity)
.with_note("invalid return activity")
.emit();
return AutoDiffAttrs::error();
}
};

// Now parse all the intermediate (input) activities
let msg_arg_activity = "autodiff attribute must contain the return activity";
let mut arg_activities: Vec<DiffActivity> = vec![];
for arg in input_activities {
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
p2.segments.first().unwrap().ident
} else {
tcx.dcx()
.struct_span_err(attr.span, msg_arg_activity)
.with_note("Implementation bug, please report this!")
.emit();
return AutoDiffAttrs::error();
};

match DiffActivity::from_str(arg_symbol.as_str()) {
Ok(arg_activity) => arg_activities.push(arg_activity),
Err(_) => {
tcx.dcx()
.struct_span_err(attr.span, msg_unknown_ret_activity)
.with_note("invalid input activity")
.emit();
return AutoDiffAttrs::error();
}
}
}

let mut msg = "".to_string();
for &input in &arg_activities {
if !valid_input_activity(mode, input) {
msg = format!("Invalid input activity {} for {} mode", input, mode);
}
}
if !valid_ret_activity(mode, ret_activity) {
msg = format!("Invalid return activity {} for {} mode", ret_activity, mode);
}
if msg != "".to_string() {
tcx.dcx().struct_span_err(attr.span, msg).with_note("invalid activity").emit();
return AutoDiffAttrs::error();
}

AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
}

pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
*providers =
Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers };
}
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_ssa/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
pub cgu_name: &'a str,
}

/// Diagnostic for autodiff being requested without fat LTO
/// (message: `codegen_ssa_autodiff_without_lto`).
#[derive(Diagnostic)]
#[diag(codegen_ssa_autodiff_without_lto)]
pub struct AutodiffWithoutLto;

#[derive(Diagnostic)]
#[diag(codegen_ssa_unknown_reuse_kind)]
pub(crate) struct UnknownReuseKind {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
type ModuleBuffer: ModuleBufferMethods;
type ThinData: Send + Sync;
type ThinBuffer: ThinBufferMethods;
//type TypeTree: Clone;

/// Merge all modules into main_module and returning it
fn run_link(
Expand All @@ -38,6 +39,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
) -> Result<(Vec<LtoModuleCodegen<Self>>, Vec<WorkProduct>), FatalError>;
fn print_pass_timings(&self);
fn print_statistics(&self);
// does enzyme prep work, should do ad too.
unsafe fn optimize(
cgcx: &CodegenContext<Self>,
dcx: DiagCtxtHandle<'_>,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_interface/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ fn test_unstable_options_tracking_hash() {
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(assume_incomplete_release, true);
tracked!(autodiff, vec![String::from("ad_flags")]);
tracked!(binary_dep_depinfo, true);
tracked!(box_noalias, false);
tracked!(
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/messages.ftl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe
middle_unsupported_union = we don't support unions yet: '{$ty_name}'
middle_adjust_for_foreign_abi_error =
target architecture {$arch} does not support `extern {$abi}` ABI
Expand Down
Loading

0 comments on commit 2ad340e

Please sign in to comment.