Skip to content

fix autodiff macro on generic functions #140049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ mod llvm_enzyme {
}

// Get information about the function the macro is applied to
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
_ => None,
}
Expand Down Expand Up @@ -210,16 +210,18 @@ mod llvm_enzyme {
}
let dcx = ecx.sess.dcx();

// first get information about the annotable item:
let Some((vis, sig, primal)) = (match &item {
// first get information about the annotable item: visibility, signature, name and generic
// parameters.
// these will be used to generate the differentiated version of the function
let Some((vis, sig, primal, generics)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
_ => None,
},
Expand Down Expand Up @@ -303,14 +305,15 @@ mod llvm_enzyme {
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
);

// The first element of it is the name of the function to be generated
let asdf = Box::new(ast::Fn {
defaultness: ast::Defaultness::Final,
sig: d_sig,
ident: first_ident(&meta_item_vec[0]),
generics: Generics::default(),
generics,
contract: None,
body: Some(d_body),
define_opaque: None,
Expand Down Expand Up @@ -475,6 +478,7 @@ mod llvm_enzyme {
new_decl_span: Span,
idents: &[Ident],
errored: bool,
generics: &Generics,
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
Expand All @@ -497,7 +501,7 @@ mod llvm_enzyme {
};
let unsf_expr = ecx.expr_block(P(unsf_block));
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
let primal_call = gen_primal_call(ecx, span, primal, idents);
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
let black_box_primal_call = ecx.expr_call(
new_decl_span,
blackbox_call_expr.clone(),
Expand Down Expand Up @@ -546,6 +550,7 @@ mod llvm_enzyme {
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
generics: &Generics,
) -> P<ast::Block> {
let new_decl_span = d_sig.span;

Expand All @@ -566,6 +571,7 @@ mod llvm_enzyme {
new_decl_span,
&idents,
errored,
generics,
);

if !has_ret(&d_sig.decl.output) {
Expand Down Expand Up @@ -608,7 +614,6 @@ mod llvm_enzyme {
panic!("Did not expect Default ret ty: {:?}", span);
}
};

if x.mode.is_fwd() {
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
Expand Down Expand Up @@ -668,8 +673,10 @@ mod llvm_enzyme {
span: Span,
primal: Ident,
idents: &[Ident],
generics: &Generics,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing this through 3x is a bit ugly :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a discussion about fully dropping the function body and just generating a declaration using #[rustc_intrinsic] ^^ #wg-autodiff > Placeholder function design @ 💬
This would make the frontend trivial, but complicate things a bit in the middle or backend. I think it's still an overall win, but I haven't looked into it.

) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;

if has_self {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, I left the code only in the lower branch.
@ZuseZ4 could you quickly explain why we have this explicit branching here?

Copy link
Member

@ZuseZ4 ZuseZ4 Apr 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume it's for methods, e.g. #139557
But we don't have tests in tests/pretty/ using self, and as per issue I also broke support for it at some point, so .. that checks out.

I think that this code handles calling another method in the dummy body, so instead of generating
bench_black_box(d_psi(self, j, 1.0)); we'd generate bench_black_box(self.d_psi(j, 1.0)); That code is currently broken and not tested anyway, so I think it's fine to not change it. The next PR fixing it can take a look.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the impl PR is already being rolled up I will try to fix it as part of this PR after I have rebased

let args: ThinVec<_> =
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
Expand All @@ -678,7 +685,51 @@ mod llvm_enzyme {
} else {
let args: ThinVec<_> =
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
let mut primal_path = ecx.path_ident(span, primal);

let is_generic = !generics.params.is_empty();

match (is_generic, primal_path.segments.last_mut()) {
(true, Some(function_path)) => {
let primal_generic_types = generics
.params
.iter()
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));

let generated_generic_types = primal_generic_types
.map(|type_param| {
let generic_param = TyKind::Path(
None,
ast::Path {
span,
segments: thin_vec![ast::PathSegment {
ident: type_param.ident,
args: None,
id: ast::DUMMY_NODE_ID,
}],
tokens: None,
},
);

ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
id: type_param.id,
span,
kind: generic_param,
tokens: None,
})))
})
.collect();

function_path.args =
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
span,
args: generated_generic_types,
})));
}
_ => {}
}

let primal_call_expr = ecx.expr_path(primal_path);
ecx.expr_call(span, primal_call_expr, args)
}
}
Expand Down
42 changes: 42 additions & 0 deletions tests/codegen/autodiff/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff;

#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}

// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double

// Ensure that `d_square::<f32>` code is generated
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs: {{.*}}
// CHECK-NEXT: define internal {{.*}} float
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul float

fn main() {
let xf32: f32 = std::hint::black_box(3.0);
let xf64: f64 = std::hint::black_box(3.0);

let outputf32 = square::<f32>(&xf32);
assert_eq!(9.0, outputf32);

let mut df_dxf64: f64 = std::hint::black_box(0.0);

let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
assert_eq!(6.0, df_dxf64);
}
14 changes: 14 additions & 0 deletions tests/pretty/autodiff/autodiff_forward.pp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

// We want to make sure that we can use the macro for functions defined inside of functions

// Make sure we can handle generics

::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
Expand Down Expand Up @@ -181,4 +183,16 @@
::core::hint::black_box(<f32>::default())
}
}
#[rustc_autodiff]
#[inline(never)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10::<T>(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10::<T>(x))
}
fn main() {}
6 changes: 6 additions & 0 deletions tests/pretty/autodiff/autodiff_forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,10 @@ pub fn f9() {
}
}

// Make sure we can handle generics
#[autodiff(d_square, Reverse, Duplicated, Active)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}

fn main() {}
Loading