From 39b35bc34a27742a84943400f286cb439f179dd9 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sat, 4 Nov 2023 18:00:28 +0000 Subject: [PATCH] grad functions work --- src/codegen/codegen.rs | 4 ++ src/codegen/compiler.rs | 117 ++++++++++++++++++++++++++++++------- src/codegen/data_layout.rs | 22 ++++++- src/codegen/sundials.rs | 6 +- 4 files changed, 121 insertions(+), 28 deletions(-) diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index df57fd1..1fe2032 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -849,6 +849,10 @@ impl<'ctx> CodeGen<'ctx> { self.insert_state(model.state(), model.state_dot()); self.insert_data(model); + + for a in model.state_dep_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)))?; + } self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?; self.builder.build_return(None)?; diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index 2335711..494962f 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -558,12 +558,57 @@ mod tests { object.write_object_file(path).unwrap(); } + fn tensor_test_common(text: &str, tmp_loc: &str, tensor_name: &str) -> Vec> { + let full_text = format!(" + {} + ", text); + let model = parse_ds_string(full_text.as_str()).unwrap(); + let discrete_model = match DiscreteModel::build("$name", &model) { + Ok(model) => { + model + } + Err(e) => { + panic!("{}", e.as_error_message(full_text.as_str())); + } + }; + let compiler = Compiler::from_discrete_model(&discrete_model, tmp_loc).unwrap(); + let mut u0 = vec![1.]; + let mut up0 = vec![1.]; + let mut res = vec![0.]; + let mut data = compiler.get_new_data(); + let mut grad_data = Vec::new(); + let (_n_states, n_inputs, _n_outputs, _n_data, _n_indices) = compiler.get_dims(); + for _ in 0..n_inputs { + grad_data.push(compiler.get_new_data()); + } + let mut results = Vec::new(); + let inputs = vec![1.; n_inputs]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap(); + compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + compiler.calc_out(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice()).unwrap(); + results.push(compiler.get_tensor_data(tensor_name, data.as_slice()).unwrap().to_vec()); + for i in 0..n_inputs { + let mut dinputs = vec![0.; n_inputs]; + dinputs[i] = 1.0; + let mut ddata = compiler.get_new_data(); + let mut du0 = vec![0.]; + let mut dup0 = vec![0.]; + let mut dres = vec![0.]; + compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); + compiler.calc_out_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); + results.push(compiler.get_tensor_data(tensor_name, ddata.as_slice()).unwrap().to_vec()); + } + results + } + macro_rules! tensor_test { ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => { $( #[test] fn $name() { - let text = $text; let full_text = format!(" {} u_i {{ @@ -581,27 +626,10 @@ mod tests { out_i {{ y, }} - ", text); - let model = parse_ds_string(full_text.as_str()).unwrap(); - let discrete_model = match DiscreteModel::build("$name", &model) { - Ok(model) => { - model - } - Err(e) => { - panic!("{}", e.as_error_message(full_text.as_str())); - } - }; - let compiler = Compiler::from_discrete_model(&discrete_model, concat!("test_output/compiler_tensor_test_", stringify!($name))).unwrap(); - let inputs = vec![]; - let mut u0 = vec![1.]; - let mut up0 = vec![1.]; - let mut res = vec![0.]; - let mut data = compiler.get_new_data(); - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap(); - compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); - compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); - let tensor = compiler.get_tensor_data($tensor_name, data.as_slice()).unwrap(); - assert_relative_eq!(tensor, $expected_value.as_slice()); + ", $text); + let tmp_loc = format!("test_output/compiler_tensor_test_{}", stringify!($name)); + let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name); + assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); } )* } @@ -629,6 +657,51 @@ mod tests { dense_matrix_vect_multiply: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![5., 11.], } + macro_rules! tensor_grad_test { + ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => { + $( + #[test] + fn $name() { + let full_text = format!(" + in = [p] + p {{ + 1, + }} + u_i {{ + y = p, + }} + dudt_i {{ + dydt = p, + }} + {} + F_i {{ + dydt, + }} + G_i {{ + y, + }} + out_i {{ + y, + }} + ", $text); + let tmp_loc = format!("test_output/compiler_tensor_grad_test_{}", stringify!($name)); + let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name); + assert_relative_eq!(results[1].as_slice(), $expected_value.as_slice()); + } + )* + } + } + + tensor_grad_test! { + const_grad: "r { 3 }" expect "r" vec![0.], + const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.], + input_grad: "r { 2 * p * p }" expect "r" vec![4.], + input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.], + state_grad: "r { 2 * y }" expect "r" vec![2.], + input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.], + } + + #[test] fn test_additional_functions() { let full_text = " diff --git a/src/codegen/data_layout.rs b/src/codegen/data_layout.rs index 22e4c6f..5f12fd5 100644 --- a/src/codegen/data_layout.rs +++ b/src/codegen/data_layout.rs @@ -37,11 +37,14 @@ impl DataLayout { let mut layout_map = HashMap::new(); let mut add_tensor = |tensor: &Tensor| { + let is_state = tensor.name() == "u" || tensor.name() == "dudt"; // insert the data (non-zeros) for each tensor layout_map.insert(tensor.name().to_string(), tensor.layout_ptr().clone()); - data_index_map.insert(tensor.name().to_string(), data.len()); - data_length_map.insert(tensor.name().to_string(), tensor.nnz()); - data.extend(vec![0.0; tensor.nnz()]); + if !is_state { + data_index_map.insert(tensor.name().to_string(), data.len()); + data_length_map.insert(tensor.name().to_string(), tensor.nnz()); + data.extend(vec![0.0; tensor.nnz()]); + } // add the translation info for each block-tensor pair @@ -88,6 +91,19 @@ impl DataLayout { pub fn get_data_index(&self, name: &str) -> Option { self.data_index_map.get(name).map(|i| *i) } + + pub fn format_data(&self, data: &[f64]) -> String { + let mut data_index_sorted: Vec<_> = self.data_index_map.iter().collect(); + data_index_sorted.sort_by_key(|(_, index)| **index); + let mut s = String::new(); + s += "["; + for (name, index) in data_index_sorted { + let nnz = self.data_length_map[name]; + s += &format!("{}: {:?}, ", name, &data[*index..*index+nnz]); + } + s += "]"; + s + } pub fn get_tensor_data(&self, name: &str) -> Option<&[f64]> { let index = self.get_data_index(name)?; diff --git a/src/codegen/sundials.rs b/src/codegen/sundials.rs index 0394034..ae4a1c8 100644 --- a/src/codegen/sundials.rs +++ b/src/codegen/sundials.rs @@ -87,8 +87,8 @@ impl Sundials { } - pub fn from_discrete_model<'m>(model: &'m DiscreteModel, options: Options) -> Result { - let compiler = Compiler::from_discrete_model(model, model.name()).unwrap(); + pub fn from_discrete_model<'m>(model: &'m DiscreteModel, options: Options, out: &str) -> Result { + let compiler = Compiler::from_discrete_model(model, out).unwrap(); let number_of_states = compiler.number_of_states() as i64; let number_of_parameters = compiler.number_of_parameters(); @@ -403,7 +403,7 @@ mod tests { let discrete = DiscreteModel::from(&model_info); println!("{}", discrete); let options = Options::new(); - let mut sundials = Sundials::from_discrete_model(&discrete, options).unwrap(); + let mut sundials = Sundials::from_discrete_model(&discrete, options, "test_output/sundials_logistic_growth").unwrap(); let times = Array::linspace(0., 1., 5);