From d9e9c9ce1c98ace037c091c90db19d908f3043fd Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 21 Dec 2023 20:00:30 -0700 Subject: [PATCH] autodiff: no_std support (switch std:: to core::) I can now do this no a device function and the IR looks okay by eyeball. argo +enzyme rustc --release --target=nvptx64-nvidia-cuda -Zbuild-std -- --emit=llvm-ir --- library/autodiff/src/gen.rs | 8 ++++---- .../autodiff/tests/expand/forward_duplicated.expanded.rs | 4 ++-- .../tests/expand/forward_duplicated_return.expanded.rs | 4 ++-- .../autodiff/tests/expand/reverse_duplicated.expanded.rs | 4 ++-- .../tests/expand/reverse_return_array.expanded.rs | 4 ++-- .../tests/expand/reverse_return_mixed.expanded.rs | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs index 68aae56ea3311..59b37b997e8cf 100644 --- a/library/autodiff/src/gen.rs +++ b/library/autodiff/src/gen.rs @@ -107,11 +107,11 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { res_inputs.push(input.clone()); match (item.header.mode, activity, is_ref_mut(&input)) { - (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => { + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { res_inputs.push(as_ref_mut(&input, "grad", true)); add_inputs.push(as_ref_mut(&input, "grad", true)); } - (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => { + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(false)) => { res_inputs.push(as_ref_mut(&input, "dual", false)); add_inputs.push(as_ref_mut(&input, "dual", false)); out_type.clone().map(|x| outputs.push(x)); @@ -203,9 +203,9 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { }; let body = quote!({ - std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + core::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box(unsafe { core::mem::zeroed() }) }); let header = generate_header(&item); diff --git a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs index bf3890154ab8e..c3e30939f92d4 100644 --- a/library/autodiff/tests/expand/forward_duplicated.expanded.rs +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -5,6 +5,6 @@ fn square(a: &Vec, b: &mut f32) { } #[autodiff_into(Forward, Const, Duplicated, Duplicated)] fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { - std::hint::black_box((square(a, b), dual_a, grad_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square(a, b), dual_a, grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs index a3754de7ab70b..12b3cb898797f 100644 --- a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -10,6 +10,6 @@ fn d_square2( b: &Vec, dual_b: &Vec, ) -> (f32, f32, f32) { - std::hint::black_box((square2(a, b), dual_a, dual_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square2(a, b), dual_a, dual_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs index 60c0d7f2f696b..04f462a09b2dd 100644 --- a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -5,6 +5,6 @@ fn square(a: &Vec, b: &mut f32) { } #[autodiff_into(Reverse, Const, Duplicated, Duplicated)] fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { - std::hint::black_box((square(a, b), grad_a, grad_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square(a, b), grad_a, grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs index 5b784157fea7b..48e0d99fd2797 100644 --- a/library/autodiff/tests/expand/reverse_return_array.expanded.rs +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -5,6 +5,6 @@ fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { } #[autodiff_into(Reverse, Active, Duplicated)] fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { - std::hint::black_box((array(arr), grad_arr, tang_y)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((array(arr), grad_arr, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs index f49864fb7e9b9..3517912222615 100644 --- a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -12,6 +12,6 @@ fn d_sqrt( d: f32, tang_y: f32, ) -> (f32, f32) { - std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) }