diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8c5c20c7af48c..1ff4fc6aaab22 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -73,10 +73,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident)> { + fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { - ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone())) + ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -210,16 +210,18 @@ mod llvm_enzyme { } let dcx = ecx.sess.dcx(); - // first get information about the annotable item: - let Some((vis, sig, primal)) = (match &item { + // first get information about the annotable item: visibility, signature, name and generic + // parameters. + // these will be used to generate the differentiated version of the function + let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, @@ -303,6 +305,7 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, + &generics, ); // The first element of it is the name of the function to be generated @@ -310,7 +313,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: Generics::default(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -475,6 +478,7 @@ mod llvm_enzyme { new_decl_span: Span, idents: &[Ident], errored: bool, + generics: &Generics, ) -> (P, P, P, P) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { @@ -497,7 +501,7 @@ mod llvm_enzyme { }; let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents); + let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( new_decl_span, blackbox_call_expr.clone(), @@ -546,6 +550,7 @@ mod llvm_enzyme { sig_span: Span, idents: Vec, errored: bool, + generics: &Generics, ) -> P { let new_decl_span = d_sig.span; @@ -566,6 +571,7 @@ mod llvm_enzyme { new_decl_span, &idents, errored, + generics, ); if !has_ret(&d_sig.decl.output) { @@ -608,7 +614,6 @@ mod llvm_enzyme { panic!("Did not expect Default ret ty: {:?}", span); } }; - if x.mode.is_fwd() { // Fwd mode is easy. If the return activity is Const, we support arbitrary types. // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. @@ -668,8 +673,10 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], + generics: &Generics, ) -> P { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); @@ -678,7 +685,51 @@ mod llvm_enzyme { } else { let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let mut primal_path = ecx.path_ident(span, primal); + + let is_generic = !generics.params.is_empty(); + + match (is_generic, primal_path.segments.last_mut()) { + (true, Some(function_path)) => { + let primal_generic_types = generics + .params + .iter() + .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); + + let generated_generic_types = primal_generic_types + .map(|type_param| { + let generic_param = TyKind::Path( + None, + ast::Path { + span, + segments: thin_vec![ast::PathSegment { + ident: type_param.ident, + args: None, + id: ast::DUMMY_NODE_ID, + }], + tokens: None, + }, + ); + + ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { + id: type_param.id, + span, + kind: generic_param, + tokens: None, + }))) + }) + .collect(); + + function_path.args = + Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { + span, + args: generated_generic_types, + }))); + } + _ => {} + } + + let primal_call_expr = ecx.expr_path(primal_path); ecx.expr_call(span, primal_call_expr, args) } } diff --git a/tests/codegen/autodiff/generic.rs b/tests/codegen/autodiff/generic.rs new file mode 100644 index 0000000000000..15e7d8a495745 --- /dev/null +++ b/tests/codegen/autodiff/generic.rs @@ -0,0 +1,42 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square, Reverse, Duplicated, Active)] +fn square + Copy>(x: &T) -> T { + *x * *x +} + +// Ensure that `d_square::` code is generated even if `square::` was never called +// +// CHECK: ; generic::square +// CHECK-NEXT: ; Function Attrs: +// CHECK-NEXT: define internal {{.*}} double +// CHECK-NEXT: start: +// CHECK-NOT: ret +// CHECK: fmul double + +// Ensure that `d_square::` code is generated +// +// CHECK: ; generic::square +// CHECK-NEXT: ; Function Attrs: {{.*}} +// CHECK-NEXT: define internal {{.*}} float +// CHECK-NEXT: start: +// CHECK-NOT: ret +// CHECK: fmul float + +fn main() { + let xf32: f32 = std::hint::black_box(3.0); + let xf64: f64 = std::hint::black_box(3.0); + + let outputf32 = square::(&xf32); + assert_eq!(9.0, outputf32); + + let mut df_dxf64: f64 = std::hint::black_box(0.0); + + let output_f64 = d_square::(&xf64, &mut df_dxf64, 1.0); + assert_eq!(6.0, df_dxf64); +} diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 713b8f541ae0d..8253603e8079c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -31,6 +31,8 @@ // We want to make sure that we can use the macro for functions defined inside of functions + // Make sure we can handle generics + ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] @@ -181,4 +183,16 @@ ::core::hint::black_box(::default()) } } +#[rustc_autodiff] +#[inline(never)] +pub fn f10 + Copy>(x: &T) -> T { *x * *x } +#[rustc_autodiff(Reverse, 1, Duplicated, Active)] +#[inline(never)] +pub fn d_square + + Copy>(x: &T, dx_0: &mut T, dret: T) -> T { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f10::(x)); + ::core::hint::black_box((dx_0, dret)); + ::core::hint::black_box(f10::(x)) +} fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index 5a0660a08e526..ae974f9b4dbf9 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -63,4 +63,10 @@ pub fn f9() { } } +// Make sure we can handle generics +#[autodiff(d_square, Reverse, Duplicated, Active)] +pub fn f10 + Copy>(x: &T) -> T { + *x * *x +} + fn main() {}