grad functions work
martinjrobins committed Nov 4, 2023
1 parent 86cf87d commit 39b35bc
Showing 4 changed files with 121 additions and 28 deletions.
4 changes: 4 additions & 0 deletions src/codegen/codegen.rs
@@ -849,6 +849,10 @@ impl<'ctx> CodeGen<'ctx> {

self.insert_state(model.state(), model.state_dot());
self.insert_data(model);

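// make sure tensors defined in terms of the state are recompiled here, so
// they are up to date before the output tensor is evaluated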
for a in model.state_dep_defns() {
self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
}

self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?;
self.builder.build_return(None)?;
117 changes: 95 additions & 22 deletions src/codegen/compiler.rs
@@ -558,12 +558,57 @@ mod tests {
object.write_object_file(path).unwrap();
}

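// Shared helper for the tensor test macros: builds and compiles the model,
// runs the primal functions once, then, for each input, seeds a unit tangent
// (dinputs[i] = 1) and runs the *_grad variants. Returns the named tensor's
// values followed by one gradient vector per input.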
fn tensor_test_common(text: &str, tmp_loc: &str, tensor_name: &str) -> Vec<Vec<f64>> {
let full_text = format!("
{}
", text);
let model = parse_ds_string(full_text.as_str()).unwrap();
let discrete_model = match DiscreteModel::build("$name", &model) {
Ok(model) => {
model
}
Err(e) => {
panic!("{}", e.as_error_message(full_text.as_str()));
}
};
let compiler = Compiler::from_discrete_model(&discrete_model, tmp_loc).unwrap();
let mut u0 = vec![1.];
let mut up0 = vec![1.];
let mut res = vec![0.];
let mut data = compiler.get_new_data();
let mut grad_data = Vec::new();
let (_n_states, n_inputs, _n_outputs, _n_data, _n_indices) = compiler.get_dims();
for _ in 0..n_inputs {
grad_data.push(compiler.get_new_data());
}
let mut results = Vec::new();
let inputs = vec![1.; n_inputs];
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap();
compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap();
compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap();
compiler.calc_out(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice()).unwrap();
results.push(compiler.get_tensor_data(tensor_name, data.as_slice()).unwrap().to_vec());
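// seed each input direction in turn and evaluate the forward-mode variants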
for i in 0..n_inputs {
let mut dinputs = vec![0.; n_inputs];
dinputs[i] = 1.0;
let mut ddata = compiler.get_new_data();
let mut du0 = vec![0.];
let mut dup0 = vec![0.];
let mut dres = vec![0.];
compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap();
compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap();
compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap();
compiler.calc_out_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap();
results.push(compiler.get_tensor_data(tensor_name, ddata.as_slice()).unwrap().to_vec());
}
results
}

macro_rules! tensor_test {
($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => {
$(
#[test]
fn $name() {
let text = $text;
let full_text = format!("
{}
u_i {{
@@ -581,27 +626,10 @@
out_i {{
y,
}}
", text);
let model = parse_ds_string(full_text.as_str()).unwrap();
let discrete_model = match DiscreteModel::build("$name", &model) {
Ok(model) => {
model
}
Err(e) => {
panic!("{}", e.as_error_message(full_text.as_str()));
}
};
let compiler = Compiler::from_discrete_model(&discrete_model, concat!("test_output/compiler_tensor_test_", stringify!($name))).unwrap();
let inputs = vec![];
let mut u0 = vec![1.];
let mut up0 = vec![1.];
let mut res = vec![0.];
let mut data = compiler.get_new_data();
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap();
compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap();
compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap();
let tensor = compiler.get_tensor_data($tensor_name, data.as_slice()).unwrap();
assert_relative_eq!(tensor, $expected_value.as_slice());
", $text);
let tmp_loc = format!("test_output/compiler_tensor_test_{}", stringify!($name));
let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name);
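// results[0] holds the tensor values from the primal (non-grad) pass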
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
}
)*
}
@@ -629,6 +657,51 @@ mod tests {
dense_matrix_vect_multiply: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![5., 11.],
}

macro_rules! tensor_grad_test {
($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => {
$(
#[test]
fn $name() {
let full_text = format!("
in = [p]
p {{
1,
}}
u_i {{
y = p,
}}
dudt_i {{
dydt = p,
}}
{}
F_i {{
dydt,
}}
G_i {{
y,
}}
out_i {{
y,
}}
", $text);
let tmp_loc = format!("test_output/compiler_tensor_grad_test_{}", stringify!($name));
let results = tensor_test_common(full_text.as_str(), tmp_loc.as_str(), $tensor_name);
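// results[1] holds d(tensor)/dp, the gradient with respect to the single input p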
assert_relative_eq!(results[1].as_slice(), $expected_value.as_slice());
}
)*
}
}

tensor_grad_test! {
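// expected values are d(r)/dp at p = 1; note y and dydt are both set to p,
// so the state tangent dy is 1 as well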
const_grad: "r { 3 }" expect "r" vec![0.],
const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.],
input_grad: "r { 2 * p * p }" expect "r" vec![4.],
input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.],
state_grad: "r { 2 * y }" expect "r" vec![2.],
input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.],
}


#[test]
fn test_additional_functions() {
let full_text = "
22 changes: 19 additions & 3 deletions src/codegen/data_layout.rs
@@ -37,11 +37,14 @@ impl DataLayout {
let mut layout_map = HashMap::new();

let mut add_tensor = |tensor: &Tensor| {
let is_state = tensor.name() == "u" || tensor.name() == "dudt";
// insert the data (non-zeros) for each tensor
layout_map.insert(tensor.name().to_string(), tensor.layout_ptr().clone());
data_index_map.insert(tensor.name().to_string(), data.len());
data_length_map.insert(tensor.name().to_string(), tensor.nnz());
data.extend(vec![0.0; tensor.nnz()]);
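// u and dudt are passed to the compiled functions as separate arrays,
// so state tensors get no slot in the data array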
if !is_state {
data_index_map.insert(tensor.name().to_string(), data.len());
data_length_map.insert(tensor.name().to_string(), tensor.nnz());
data.extend(vec![0.0; tensor.nnz()]);
}


// add the translation info for each block-tensor pair
@@ -88,6 +91,19 @@
pub fn get_data_index(&self, name: &str) -> Option<usize> {
self.data_index_map.get(name).map(|i| *i)
}

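/// Debug helper: formats the data slice tensor-by-tensor in index order,
/// producing e.g. "[r: [3.0], out: [1.0], ]" (tensor names illustrative).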
pub fn format_data(&self, data: &[f64]) -> String {
let mut data_index_sorted: Vec<_> = self.data_index_map.iter().collect();
data_index_sorted.sort_by_key(|(_, index)| **index);
let mut s = String::new();
s += "[";
for (name, index) in data_index_sorted {
let nnz = self.data_length_map[name];
s += &format!("{}: {:?}, ", name, &data[*index..*index+nnz]);
}
s += "]";
s
}

pub fn get_tensor_data(&self, name: &str) -> Option<&[f64]> {
let index = self.get_data_index(name)?;
6 changes: 3 additions & 3 deletions src/codegen/sundials.rs
@@ -87,8 +87,8 @@ impl Sundials {
}


pub fn from_discrete_model<'m>(model: &'m DiscreteModel, options: Options) -> Result<Sundials> {
let compiler = Compiler::from_discrete_model(model, model.name()).unwrap();
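// `out` is the path the compiler writes its output to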
pub fn from_discrete_model<'m>(model: &'m DiscreteModel, options: Options, out: &str) -> Result<Sundials> {
let compiler = Compiler::from_discrete_model(model, out).unwrap();
let number_of_states = compiler.number_of_states() as i64;
let number_of_parameters = compiler.number_of_parameters();

@@ -403,7 +403,7 @@ mod tests {
let discrete = DiscreteModel::from(&model_info);
println!("{}", discrete);
let options = Options::new();
let mut sundials = Sundials::from_discrete_model(&discrete, options).unwrap();
let mut sundials = Sundials::from_discrete_model(&discrete, options, "test_output/sundials_logistic_growth").unwrap();

let times = Array::linspace(0., 1., 5);

