diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index 1fe2032..412c98d 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -35,6 +35,7 @@ pub type GetOutFunc = unsafe extern "C" fn(data: *const realtype, tensor_data: * struct EnzymeGlobals<'ctx> { enzyme_dup: GlobalValue<'ctx>, enzyme_const: GlobalValue<'ctx>, + enzyme_dupnoneed: GlobalValue<'ctx>, } impl<'ctx> EnzymeGlobals<'ctx> { @@ -43,10 +44,17 @@ impl<'ctx> EnzymeGlobals<'ctx> { Ok(Self { enzyme_dup: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dup"), enzyme_const: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_const"), + enzyme_dupnoneed: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dupnoneed"), }) } } +pub enum CompileGradientArgType { + Const, + Dup, + DupNoNeed, +} + pub struct CodeGen<'ctx> { context: &'ctx inkwell::context::Context, @@ -850,9 +858,7 @@ impl<'ctx> CodeGen<'ctx> { self.insert_state(model.state(), model.state_dot()); self.insert_data(model); - for a in model.state_dep_defns() { - self.jit_compile_tensor(a, Some(*self.get_var(a)))?; - } + // TODO: could split state dep defns into before and after F and G self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?; self.builder.build_return(None)?; @@ -898,6 +904,11 @@ impl<'ctx> CodeGen<'ctx> { for tensor in model.time_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; } + + // TODO: could split state dep defns into before and after F and G + for a in model.state_dep_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)))?; + } // F and G self.jit_compile_tensor(&model.lhs(), Some(*self.get_var(model.lhs())))?; @@ -922,7 +933,7 @@ impl<'ctx> CodeGen<'ctx> { } } - pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_is_const: &[bool]) -> Result> { + pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_type: &[CompileGradientArgType]) -> Result> { self.clear(); let enzyme_globals = match self.enzyme_globals { @@ -943,10 +954,12 @@ impl<'ctx> CodeGen<'ctx> { enzyme_fn_type.push(self.int_type.into()); enzyme_fn_type.push(arg.get_type().into()); - if !args_is_const[i] { - // active args with type T in the original funciton have 3 args of type [int, T, T] - fn_type.push(arg.get_type().into()); - enzyme_fn_type.push(arg.get_type().into()); + match args_type[i] { + CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => { + fn_type.push(arg.get_type().into()); + enzyme_fn_type.push(arg.get_type().into()); + } + CompileGradientArgType::Const => {}, } } let void_type = self.context.void_type(); @@ -959,26 +972,42 @@ impl<'ctx> CodeGen<'ctx> { let mut enzyme_fn_args: Vec = vec![original_function.as_global_value().as_pointer_value().into()]; let enzyme_const = self.builder.build_load(enzyme_globals.enzyme_const.as_pointer_value(), "enzyme_const")?; - let enzyme_dup= self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_const")?; + let enzyme_dup= self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_dup")?; + let enzyme_dupnoneed = self.builder.build_load(enzyme_globals.enzyme_dupnoneed.as_pointer_value(), "enzyme_dupnoneed")?; for (i, _arg) in original_function.get_param_iter().enumerate() { let param_index = start_param_index[i]; let fn_arg = function.get_nth_param(param_index).unwrap(); - if args_is_const[i] { - // let enzyme know its a constant arg - enzyme_fn_args.push(enzyme_const.into()); - - // pass in the arg value - enzyme_fn_args.push(fn_arg.into()); - } else { - // let enzyme know its an active arg - enzyme_fn_args.push(enzyme_dup.into()); - - // pass in the arg value - enzyme_fn_args.push(fn_arg.into()); - - // pass in the darg value - let fn_darg = function.get_nth_param(param_index + 1).unwrap(); - enzyme_fn_args.push(fn_darg.into()); + match args_type[i] { + CompileGradientArgType::Dup => { + // let enzyme know its an active arg + enzyme_fn_args.push(enzyme_dup.into()); + + // pass in the arg value + enzyme_fn_args.push(fn_arg.into()); + + // pass in the darg value + let fn_darg = function.get_nth_param(param_index + 1).unwrap(); + enzyme_fn_args.push(fn_darg.into()); + }, + CompileGradientArgType::DupNoNeed => { + // let enzyme know its an active arg we don't need + enzyme_fn_args.push(enzyme_dupnoneed.into()); + + // pass in the arg value + enzyme_fn_args.push(fn_arg.into()); + + // pass in the darg value + let fn_darg = function.get_nth_param(param_index + 1).unwrap(); + enzyme_fn_args.push(fn_darg.into()); + }, + CompileGradientArgType::Const => { + // let enzyme know its a constant arg + enzyme_fn_args.push(enzyme_const.into()); + + // pass in the arg value + enzyme_fn_args.push(fn_arg.into()); + + }, } } diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index 494962f..bc53fab 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -14,6 +14,7 @@ use std::process::Command; use super::codegen::CalcOutGradientFunc; +use super::codegen::CompileGradientArgType; use super::codegen::GetDimsFunc; use super::codegen::GetOutFunc; use super::codegen::ResidualGradientFunc; @@ -90,15 +91,15 @@ impl Compiler { let mut codegen = CodeGen::new(model, &context, module, real_type, real_type_str); let _set_u0 = codegen.compile_set_u0(model)?; - let _set_u0_grad = codegen.compile_gradient(_set_u0, &[false, true, false, false])?; + let _set_u0_grad = codegen.compile_gradient(_set_u0, &[CompileGradientArgType::Dup, CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; let _residual = codegen.compile_residual(model)?; - let _residual_grad = codegen.compile_gradient(_residual, &[true, false, false, false, true, false])?; + let _residual_grad = codegen.compile_gradient(_residual, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Const, CompileGradientArgType::DupNoNeed])?; let _calc_out = codegen.compile_calc_out(model)?; - let _calc_out_grad = codegen.compile_gradient(_calc_out, &[true, false, false, false, true])?; + let _calc_out_grad = codegen.compile_gradient(_calc_out, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Const])?; let _set_id = codegen.compile_set_id(model)?; let _get_dims= codegen.compile_get_dims(model)?; let _set_inputs = codegen.compile_set_inputs(model)?; - let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[false, false])?; + let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; let _get_output = codegen.compile_get_tensor(model, "out")?; let pre_enzyme_bitcodefilename = format!("{}.pre-enzyme.bc", out); @@ -702,6 +703,62 @@ mod tests { } + #[test] + fn test_repeated_grad() { + let full_text = " + in = [p] + p { + 1, + } + u_i { + y = p, + } + dudt_i { + dydt = 1, + } + r { + 2 * y * p, + } + F_i { + dydt, + } + G_i { + r, + } + out_i { + y, + } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = match DiscreteModel::build("test_repeated_grad", &model) { + Ok(model) => { + model + } + Err(e) => { + panic!("{}", e.as_error_message(full_text)); + } + }; + let compiler = Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_repeated_grad").unwrap(); + let mut u0 = vec![1.]; + let mut up0 = vec![1.]; + let mut du0 = vec![1.]; + let mut dup0 = vec![1.]; + let mut res = vec![0.]; + let mut dres = vec![0.]; + let mut data = compiler.get_new_data(); + let mut ddata = compiler.get_new_data(); + let (_n_states, n_inputs, _n_outputs, _n_data, _n_indices) = compiler.get_dims(); + + for _ in 0..3 { + let inputs = vec![2.; n_inputs]; + let dinputs = vec![1.; n_inputs]; + compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); + assert_relative_eq!(dres.as_slice(), vec![-8.].as_slice()); + } + } + #[test] fn test_additional_functions() { let full_text = "