-
Notifications
You must be signed in to change notification settings - Fork 13.3k
fix autodiff macro on generic functions #140049
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,10 +73,10 @@ mod llvm_enzyme { | |
} | ||
|
||
// Get information about the function the macro is applied to | ||
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> { | ||
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> { | ||
match &iitem.kind { | ||
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { | ||
Some((iitem.vis.clone(), sig.clone(), ident.clone())) | ||
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { | ||
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) | ||
} | ||
_ => None, | ||
} | ||
|
@@ -210,16 +210,18 @@ mod llvm_enzyme { | |
} | ||
let dcx = ecx.sess.dcx(); | ||
|
||
// first get information about the annotable item: | ||
let Some((vis, sig, primal)) = (match &item { | ||
// first get information about the annotable item: visibility, signature, name and generic | ||
// parameters. | ||
// these will be used to generate the differentiated version of the function | ||
let Some((vis, sig, primal, generics)) = (match &item { | ||
Annotatable::Item(iitem) => extract_item_info(iitem), | ||
Annotatable::Stmt(stmt) => match &stmt.kind { | ||
ast::StmtKind::Item(iitem) => extract_item_info(iitem), | ||
_ => None, | ||
}, | ||
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { | ||
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { | ||
Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) | ||
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { | ||
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) | ||
} | ||
_ => None, | ||
}, | ||
|
@@ -303,14 +305,15 @@ mod llvm_enzyme { | |
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); | ||
let d_body = gen_enzyme_body( | ||
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, | ||
&generics, | ||
); | ||
|
||
// The first element of it is the name of the function to be generated | ||
let asdf = Box::new(ast::Fn { | ||
defaultness: ast::Defaultness::Final, | ||
sig: d_sig, | ||
ident: first_ident(&meta_item_vec[0]), | ||
generics: Generics::default(), | ||
generics, | ||
contract: None, | ||
body: Some(d_body), | ||
define_opaque: None, | ||
|
@@ -475,6 +478,7 @@ mod llvm_enzyme { | |
new_decl_span: Span, | ||
idents: &[Ident], | ||
errored: bool, | ||
generics: &Generics, | ||
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) { | ||
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); | ||
let noop = ast::InlineAsm { | ||
|
@@ -497,7 +501,7 @@ mod llvm_enzyme { | |
}; | ||
let unsf_expr = ecx.expr_block(P(unsf_block)); | ||
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); | ||
let primal_call = gen_primal_call(ecx, span, primal, idents); | ||
let primal_call = gen_primal_call(ecx, span, primal, idents, generics); | ||
let black_box_primal_call = ecx.expr_call( | ||
new_decl_span, | ||
blackbox_call_expr.clone(), | ||
|
@@ -546,6 +550,7 @@ mod llvm_enzyme { | |
sig_span: Span, | ||
idents: Vec<Ident>, | ||
errored: bool, | ||
generics: &Generics, | ||
) -> P<ast::Block> { | ||
let new_decl_span = d_sig.span; | ||
|
||
|
@@ -566,6 +571,7 @@ mod llvm_enzyme { | |
new_decl_span, | ||
&idents, | ||
errored, | ||
generics, | ||
); | ||
|
||
if !has_ret(&d_sig.decl.output) { | ||
|
@@ -608,7 +614,6 @@ mod llvm_enzyme { | |
panic!("Did not expect Default ret ty: {:?}", span); | ||
} | ||
}; | ||
|
||
if x.mode.is_fwd() { | ||
// Fwd mode is easy. If the return activity is Const, we support arbitrary types. | ||
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. | ||
|
@@ -668,8 +673,10 @@ mod llvm_enzyme { | |
span: Span, | ||
primal: Ident, | ||
idents: &[Ident], | ||
generics: &Generics, | ||
) -> P<ast::Expr> { | ||
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; | ||
|
||
if has_self { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right now, I left the code only in the lower branch. — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I assume it's for methods, e.g. #139557. I think that this code handles calling another method in the dummy body, so instead of generating […] — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since the […] |
||
let args: ThinVec<_> = | ||
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); | ||
|
@@ -678,7 +685,51 @@ mod llvm_enzyme { | |
} else { | ||
let args: ThinVec<_> = | ||
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); | ||
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); | ||
let mut primal_path = ecx.path_ident(span, primal); | ||
|
||
let is_generic = !generics.params.is_empty(); | ||
|
||
match (is_generic, primal_path.segments.last_mut()) { | ||
(true, Some(function_path)) => { | ||
let primal_generic_types = generics | ||
.params | ||
.iter() | ||
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); | ||
|
||
let generated_generic_types = primal_generic_types | ||
.map(|type_param| { | ||
let generic_param = TyKind::Path( | ||
None, | ||
ast::Path { | ||
span, | ||
segments: thin_vec![ast::PathSegment { | ||
ident: type_param.ident, | ||
args: None, | ||
id: ast::DUMMY_NODE_ID, | ||
}], | ||
tokens: None, | ||
}, | ||
); | ||
|
||
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { | ||
id: type_param.id, | ||
span, | ||
kind: generic_param, | ||
tokens: None, | ||
}))) | ||
}) | ||
.collect(); | ||
|
||
function_path.args = | ||
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { | ||
span, | ||
args: generated_generic_types, | ||
}))); | ||
} | ||
_ => {} | ||
} | ||
|
||
let primal_call_expr = ecx.expr_path(primal_path); | ||
ecx.expr_call(span, primal_call_expr, args) | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat | ||
//@ no-prefer-dynamic | ||
//@ needs-enzyme | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff; | ||
|
||
// Generic primal function: returns `x * x` for any `Copy` type whose `Mul` output is itself. | ||
// The `autodiff` attribute requests a reverse-mode companion named `d_square`; judging by the | ||
// call in `main`, it takes the input, a `&mut` slot for the derivative, and a value for the | ||
// return (presumably the seed -- confirm against the autodiff docs). | ||
#[autodiff(d_square, Reverse, Duplicated, Active)] | ||
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { | ||
*x * *x | ||
} | ||
|
||
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called | ||
// | ||
// CHECK: ; generic::square | ||
// CHECK-NEXT: ; Function Attrs: | ||
// CHECK-NEXT: define internal {{.*}} double | ||
// CHECK-NEXT: start: | ||
// CHECK-NOT: ret | ||
// CHECK: fmul double | ||
|
||
// Ensure that `d_square::<f32>` code is generated | ||
// | ||
// CHECK: ; generic::square | ||
// CHECK-NEXT: ; Function Attrs: {{.*}} | ||
// CHECK-NEXT: define internal {{.*}} float | ||
// CHECK-NEXT: start: | ||
// CHECK-NOT: ret | ||
// CHECK: fmul float | ||
|
||
// Drives both instantiations: `square::<f32>` is called directly, while for `f64` only the | ||
// generated `d_square::<f64>` is called, so `square::<f64>` is never referenced from here. | ||
fn main() { | ||
let xf32: f32 = std::hint::black_box(3.0); | ||
let xf64: f64 = std::hint::black_box(3.0); | ||
|
||
let outputf32 = square::<f32>(&xf32); | ||
assert_eq!(9.0, outputf32); | ||
|
||
// Slot that receives the derivative; starts at 0.0 and is asserted on after the call. | ||
let mut df_dxf64: f64 = std::hint::black_box(0.0); | ||
|
||
// The third argument is presumably the seed for the return value (TODO confirm). | ||
// With x = 3.0 the expected derivative of x * x is 2 * 3.0 = 6.0. | ||
let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0); | ||
assert_eq!(6.0, df_dxf64); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing this through 3x is a bit ugly :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a discussion about fully dropping the function body and just generating a declaration using
#[rustc_intrinsic]
^^ See #wg-autodiff > Placeholder function design @ 💬. This would make the frontend trivial, but complicate things a bit in the middle-end or backend. I think it's still an overall win, but I haven't looked into it.