don't need residual_grad output
martinjrobins committed Nov 7, 2023
1 parent 39b35bc commit e08a3c4
Showing 2 changed files with 115 additions and 29 deletions.
79 changes: 54 additions & 25 deletions src/codegen/codegen.rs
@@ -35,6 +35,7 @@ pub type GetOutFunc = unsafe extern "C" fn(data: *const realtype, tensor_data: *
struct EnzymeGlobals<'ctx> {
enzyme_dup: GlobalValue<'ctx>,
enzyme_const: GlobalValue<'ctx>,
enzyme_dupnoneed: GlobalValue<'ctx>,
}

impl<'ctx> EnzymeGlobals<'ctx> {
@@ -43,10 +44,17 @@ impl<'ctx> EnzymeGlobals<'ctx> {
Ok(Self {
enzyme_dup: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dup"),
enzyme_const: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_const"),
enzyme_dupnoneed: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dupnoneed"),
})
}
}

pub enum CompileGradientArgType {
Const,
Dup,
DupNoNeed,
}
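
These variants mirror Enzyme's activity annotations: Const marks an argument whose derivative is ignored, Dup marks an active argument that is passed together with a shadow buffer holding its derivative, and DupNoNeed additionally tells Enzyme that the primal value of that argument is not needed, so only the shadow is written. A minimal sketch of how the variants map onto the globals above (the helper name activity_global is hypothetical, not part of this commit):

    // Hypothetical helper: select the Enzyme activity marker for an argument.
    fn activity_global<'ctx>(
        arg: &CompileGradientArgType,
        globals: &EnzymeGlobals<'ctx>,
    ) -> GlobalValue<'ctx> {
        match arg {
            CompileGradientArgType::Const => globals.enzyme_const,
            CompileGradientArgType::Dup => globals.enzyme_dup,
            CompileGradientArgType::DupNoNeed => globals.enzyme_dupnoneed,
        }
    }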


pub struct CodeGen<'ctx> {
context: &'ctx inkwell::context::Context,
@@ -850,9 +858,7 @@ impl<'ctx> CodeGen<'ctx> {
self.insert_state(model.state(), model.state_dot());
self.insert_data(model);

for a in model.state_dep_defns() {
self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
}
// TODO: could split state dep defns into before and after F and G

self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?;
self.builder.build_return(None)?;
@@ -898,6 +904,11 @@ impl<'ctx> CodeGen<'ctx> {
for tensor in model.time_dep_defns() {
self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
}

// TODO: could split state dep defns into before and after F and G
for a in model.state_dep_defns() {
self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
}

// F and G
self.jit_compile_tensor(&model.lhs(), Some(*self.get_var(model.lhs())))?;
@@ -922,7 +933,7 @@
}
}

pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_is_const: &[bool]) -> Result<FunctionValue<'ctx>> {
pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_type: &[CompileGradientArgType]) -> Result<FunctionValue<'ctx>> {
self.clear();

let enzyme_globals = match self.enzyme_globals {
@@ -943,10 +954,12 @@ impl<'ctx> CodeGen<'ctx> {
enzyme_fn_type.push(self.int_type.into());
enzyme_fn_type.push(arg.get_type().into());

if !args_is_const[i] {
// active args with type T in the original function have 3 args of type [int, T, T]
fn_type.push(arg.get_type().into());
enzyme_fn_type.push(arg.get_type().into());
match args_type[i] {
CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
fn_type.push(arg.get_type().into());
enzyme_fn_type.push(arg.get_type().into());
}
CompileGradientArgType::Const => {},
}
}
let void_type = self.context.void_type();
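
Note the indexing consequence of the loop above: a Const argument occupies one slot in the wrapper's signature, while Dup and DupNoNeed occupy two (value plus shadow), which is what the start_param_index lookup below relies on. A self-contained sketch of that offset computation, assuming start_param_index is built this way (the function below is illustrative, not part of this commit):

    // Illustrative: compute each original argument's first parameter index
    // in the generated gradient wrapper.
    fn start_param_indices(args: &[CompileGradientArgType]) -> Vec<u32> {
        let mut indices = Vec::with_capacity(args.len());
        let mut next = 0u32;
        for arg in args {
            indices.push(next);
            next += match arg {
                CompileGradientArgType::Const => 1,
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => 2,
            };
        }
        indices
    }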
@@ -959,26 +972,42 @@ impl<'ctx> CodeGen<'ctx> {

let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = vec![original_function.as_global_value().as_pointer_value().into()];
let enzyme_const = self.builder.build_load(enzyme_globals.enzyme_const.as_pointer_value(), "enzyme_const")?;
let enzyme_dup= self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_const")?;
let enzyme_dup = self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_dup")?;
let enzyme_dupnoneed = self.builder.build_load(enzyme_globals.enzyme_dupnoneed.as_pointer_value(), "enzyme_dupnoneed")?;
for (i, _arg) in original_function.get_param_iter().enumerate() {
let param_index = start_param_index[i];
let fn_arg = function.get_nth_param(param_index).unwrap();
if args_is_const[i] {
// let enzyme know it's a constant arg
enzyme_fn_args.push(enzyme_const.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());
} else {
// let enzyme know it's an active arg
enzyme_fn_args.push(enzyme_dup.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

// pass in the darg value
let fn_darg = function.get_nth_param(param_index + 1).unwrap();
enzyme_fn_args.push(fn_darg.into());
match args_type[i] {
CompileGradientArgType::Dup => {
// let enzyme know it's an active arg
enzyme_fn_args.push(enzyme_dup.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

// pass in the darg value
let fn_darg = function.get_nth_param(param_index + 1).unwrap();
enzyme_fn_args.push(fn_darg.into());
},
CompileGradientArgType::DupNoNeed => {
// let enzyme know it's an active arg whose primal value is not needed
enzyme_fn_args.push(enzyme_dupnoneed.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

// pass in the darg value
let fn_darg = function.get_nth_param(param_index + 1).unwrap();
enzyme_fn_args.push(fn_darg.into());
},
CompileGradientArgType::Const => {
// let enzyme know it's a constant arg
enzyme_fn_args.push(enzyme_const.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

},
}
}
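
Putting the loop above together with the signature construction earlier, the Enzyme call receives the original function pointer followed by a (marker, value[, shadow]) group per argument. A sketch of the final ordering for the residual gradient, assuming the original residual signature is (t, u, up, data, indices, rr) as suggested by the six annotations passed in compiler.rs:

    // Illustrative final argument order (assumed residual signature):
    //   residual,
    //   enzyme_const,     t,
    //   enzyme_dup,       u,    du,
    //   enzyme_dup,       up,   dup,
    //   enzyme_dup,       data, ddata,
    //   enzyme_const,     indices,
    //   enzyme_dupnoneed, rr,   drr    // primal rr no longer materialized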

65 changes: 61 additions & 4 deletions src/codegen/compiler.rs
@@ -14,6 +14,7 @@ use std::process::Command;


use super::codegen::CalcOutGradientFunc;
use super::codegen::CompileGradientArgType;
use super::codegen::GetDimsFunc;
use super::codegen::GetOutFunc;
use super::codegen::ResidualGradientFunc;
@@ -90,15 +91,15 @@ impl Compiler {
let mut codegen = CodeGen::new(model, &context, module, real_type, real_type_str);

let _set_u0 = codegen.compile_set_u0(model)?;
let _set_u0_grad = codegen.compile_gradient(_set_u0, &[false, true, false, false])?;
let _set_u0_grad = codegen.compile_gradient(_set_u0, &[CompileGradientArgType::Dup, CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup])?;
let _residual = codegen.compile_residual(model)?;
let _residual_grad = codegen.compile_gradient(_residual, &[true, false, false, false, true, false])?;
let _residual_grad = codegen.compile_gradient(_residual, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Const, CompileGradientArgType::DupNoNeed])?;
let _calc_out = codegen.compile_calc_out(model)?;
let _calc_out_grad = codegen.compile_gradient(_calc_out, &[true, false, false, false, true])?;
let _calc_out_grad = codegen.compile_gradient(_calc_out, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Const])?;
let _set_id = codegen.compile_set_id(model)?;
let _get_dims = codegen.compile_get_dims(model)?;
let _set_inputs = codegen.compile_set_inputs(model)?;
let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[false, false])?;
let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup])?;
let _get_output = codegen.compile_get_tensor(model, "out")?;

let pre_enzyme_bitcodefilename = format!("{}.pre-enzyme.bc", out);
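
Of the four gradients compiled here, only the residual's annotations actually change relative to the old boolean flags: its final output buffer rr is now marked DupNoNeed rather than Dup, so Enzyme fills in the shadow derivative output without also having to produce the primal residual — the change the commit title describes.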
@@ -702,6 +703,62 @@ mod tests {
}


#[test]
fn test_repeated_grad() {
let full_text = "
in = [p]
p {
1,
}
u_i {
y = p,
}
dudt_i {
dydt = 1,
}
r {
2 * y * p,
}
F_i {
dydt,
}
G_i {
r,
}
out_i {
y,
}
";
let model = parse_ds_string(full_text).unwrap();
let discrete_model = match DiscreteModel::build("test_repeated_grad", &model) {
Ok(model) => {
model
}
Err(e) => {
panic!("{}", e.as_error_message(full_text));
}
};
let compiler = Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_repeated_grad").unwrap();
let mut u0 = vec![1.];
let mut up0 = vec![1.];
let mut du0 = vec![1.];
let mut dup0 = vec![1.];
let mut res = vec![0.];
let mut dres = vec![0.];
let mut data = compiler.get_new_data();
let mut ddata = compiler.get_new_data();
let (_n_states, n_inputs, _n_outputs, _n_data, _n_indices) = compiler.get_dims();

for _ in 0..3 {
let inputs = vec![2.; n_inputs];
let dinputs = vec![1.; n_inputs];
compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap();
compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap();
compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap();
assert_relative_eq!(dres.as_slice(), vec![-8.].as_slice());
}
}
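
The expected value in the assertion can be checked by hand. Assuming the residual is assembled as F - G, the single equation gives res = dydt - 2*y*p. With the inputs set to p = 2 (so y = p = 2), dp = 1, dy = dy/dp = 1, and d(dydt)/dp = 0, forward mode gives dres = 0 - 2*(y*dp + p*dy) = -2*(2 + 2) = -8, and it stays -8 on every pass — which is exactly what running the three grad calls three times in the loop verifies.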

#[test]
fn test_additional_functions() {
let full_text = "
