diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index e2794e0..accb2a2 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -190,7 +190,7 @@ impl<'ctx> CodeGen<'ctx> { let exp_negx = self.builder.build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name).ok()?; let one_plus_exp_negx = self.builder.build_float_add(exp_negx.try_as_basic_value().left().unwrap().into_float_value(), one, name).ok()?; let sigmoid = self.builder.build_float_div(one, one_plus_exp_negx, name).ok()?; - self.builder.build_return(Some(&sigmoid)); + self.builder.build_return(Some(&sigmoid)).ok(); self.builder.position_at_end(current_block); Some(fn_val) }, @@ -233,7 +233,7 @@ impl<'ctx> CodeGen<'ctx> { let name = a.name(); let elmt = a.elmts().first().unwrap(); let float_value = self.jit_compile_expr(name, &elmt.expr(), &[], elmt, None)?; - self.builder.build_store(res_ptr, float_value); + self.builder.build_store(res_ptr, float_value)?; Ok(res_ptr) } @@ -295,8 +295,8 @@ impl<'ctx> CodeGen<'ctx> { let expr_index_ptr = self.builder.build_alloca(int_type, "expr_index")?; let elmt_index_ptr = self.builder.build_alloca(int_type, "elmt_index")?; - self.builder.build_store(expr_index_ptr, zero); - self.builder.build_store(elmt_index_ptr, zero); + self.builder.build_store(expr_index_ptr, zero)?; + self.builder.build_store(elmt_index_ptr, zero)?; // setup indices, loop through the nested loops let mut indices = Vec::new(); @@ -312,7 +312,7 @@ impl<'ctx> CodeGen<'ctx> { for i in 0..expr_rank { let block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_unconditional_branch(block); + self.builder.build_unconditional_branch(block)?; self.builder.position_at_end(block); let start_index = int_type.const_int(0, false); @@ -320,7 +320,7 @@ impl<'ctx> CodeGen<'ctx> { curr_index.add_incoming(&[(&start_index, preblock)]); if i == expr_rank - contract_by - 1 && contract_sum.is_some() { - self.builder.build_store(contract_sum.unwrap(), self.real_type.const_zero()); + self.builder.build_store(contract_sum.unwrap(), self.real_type.const_zero())?; } indices.push(curr_index); @@ -335,17 +335,17 @@ impl<'ctx> CodeGen<'ctx> { let expr_index = self.builder.build_load(expr_index_ptr, "expr_index")?.into_int_value(); let elmt_index = self.builder.build_load(elmt_index_ptr, "elmt_index")?.into_int_value(); let next_expr_index = self.builder.build_int_add(expr_index, one, "next_expr_index")?; - self.builder.build_store(expr_index_ptr, next_expr_index); + self.builder.build_store(expr_index_ptr, next_expr_index)?; let float_value = self.jit_compile_expr(name, &elmt.expr(), indices_int.as_slice(), elmt, Some(expr_index))?; if contract_sum.is_some() { let contract_sum_value = self.builder.build_load(contract_sum.unwrap(), "contract_sum")?.into_float_value(); let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum")?; - self.builder.build_store(contract_sum.unwrap(), new_contract_sum_value); + self.builder.build_store(contract_sum.unwrap(), new_contract_sum_value)?; } else { preblock = self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation, preblock)?; let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index")?; - self.builder.build_store(elmt_index_ptr, next_elmt_index); + self.builder.build_store(elmt_index_ptr, next_elmt_index)?; } // unwind the nested loops @@ -357,14 +357,14 @@ impl<'ctx> CodeGen<'ctx> { if i == expr_rank - contract_by - 1 && contract_sum.is_some() { let contract_sum_value= self.builder.build_load(contract_sum.unwrap(), "contract_sum")?.into_float_value(); let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index")?; - self.builder.build_store(elmt_index_ptr, next_elmt_index); + self.builder.build_store(elmt_index_ptr, next_elmt_index)?; self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?; } // loop condition let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, expr_shape[i], name)?; let block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(loop_while, blocks[i], block); + self.builder.build_conditional_branch(loop_while, blocks[i], block)?; self.builder.position_at_end(block); preblock = block; } @@ -391,7 +391,7 @@ impl<'ctx> CodeGen<'ctx> { // loop through each contraction let block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_unconditional_branch(block); + self.builder.build_unconditional_branch(block)?; self.builder.position_at_end(block); let contract_index = self.builder.build_phi(int_type, "i")?; @@ -418,11 +418,11 @@ impl<'ctx> CodeGen<'ctx> { let end_contract = self.builder.build_load(end_ptr, "end")?.into_int_value(); // initialise the contract sum - self.builder.build_store(contract_sum_ptr, self.real_type.const_float(0.0)); + self.builder.build_store(contract_sum_ptr, self.real_type.const_float(0.0))?; // loop through each element in the contraction let contract_block = self.context.append_basic_block(self.fn_value(), format!("{}_contract", name).as_str()); - self.builder.build_unconditional_branch(contract_block); + self.builder.build_unconditional_branch(contract_block)?; self.builder.position_at_end(contract_block); let expr_index_phi = self.builder.build_phi(int_type, "j")?; @@ -431,18 +431,19 @@ impl<'ctx> CodeGen<'ctx> { // loop body - load index from layout let expr_index = expr_index_phi.as_basic_value().into_int_value(); let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; - let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { + let indices_int = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name)?; let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name)? }; - self.builder.build_load(ptr, name)?.into_int_value() - }).collect(); + let index = self.builder.build_load(ptr, name)?.into_int_value(); + Ok(index) + }).collect::, anyhow::Error>>()?; // loop body - eval expression and increment sum let float_value = self.jit_compile_expr(name, &elmt.expr(), indices_int.as_slice(), elmt, Some(expr_index))?; let contract_sum_value = self.builder.build_load(contract_sum_ptr, "contract_sum")?.into_float_value(); let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum")?; - self.builder.build_store(contract_sum_ptr, new_contract_sum_value); + self.builder.build_store(contract_sum_ptr, new_contract_sum_value)?; // increment contract loop index let next_elmt_index = self.builder.build_int_add(expr_index, int_type.const_int(1, false), name)?; @@ -451,7 +452,7 @@ impl<'ctx> CodeGen<'ctx> { // contract loop condition let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_elmt_index, end_contract, name)?; let post_contract_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(loop_while, contract_block, post_contract_block); + self.builder.build_conditional_branch(loop_while, contract_block, post_contract_block)?; self.builder.position_at_end(post_contract_block); // store the result @@ -464,7 +465,7 @@ impl<'ctx> CodeGen<'ctx> { // outer loop condition let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_contract_index, final_contract_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(loop_while, block, post_block); + self.builder.build_conditional_branch(loop_while, block, post_block)?; self.builder.position_at_end(post_block); Ok(()) @@ -480,7 +481,7 @@ impl<'ctx> CodeGen<'ctx> { let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap(); // loop through the non-zero elements let mut block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_unconditional_branch(block); + self.builder.build_unconditional_branch(block)?; self.builder.position_at_end(block); let start_index = int_type.const_int(0, false); @@ -491,12 +492,12 @@ impl<'ctx> CodeGen<'ctx> { // loop body - load index from layout let elmt_index = curr_index.as_basic_value().into_int_value(); let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; - let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { + let indices_int = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name)?; let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name)? }; - self.builder.build_load(ptr, name)?.into_int_value() - }).collect(); + Ok(self.builder.build_load(ptr, name)?.into_int_value()) + }).collect::, anyhow::Error>>()?; // loop body - eval expression let float_value = self.jit_compile_expr(name, &elmt.expr(), indices_int.as_slice(), elmt, Some(elmt_index))?; @@ -511,7 +512,7 @@ impl<'ctx> CodeGen<'ctx> { // loop condition let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(loop_while, block, post_block); + self.builder.build_conditional_branch(loop_while, block, post_block)?; self.builder.position_at_end(post_block); Ok(()) @@ -525,7 +526,7 @@ impl<'ctx> CodeGen<'ctx> { // loop through the non-zero elements let mut block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_unconditional_branch(block); + self.builder.build_unconditional_branch(block)?; self.builder.position_at_end(block); let start_index = int_type.const_int(0, false); @@ -553,7 +554,7 @@ impl<'ctx> CodeGen<'ctx> { // loop condition let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(loop_while, block, post_block); + self.builder.build_conditional_branch(loop_while, block, post_block)?; self.builder.position_at_end(post_block); Ok(()) @@ -570,7 +571,7 @@ impl<'ctx> CodeGen<'ctx> { // setup loop block let bcast_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_unconditional_branch(bcast_block); + self.builder.build_unconditional_branch(bcast_block)?; self.builder.position_at_end(bcast_block); let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?; bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]); @@ -590,7 +591,7 @@ impl<'ctx> CodeGen<'ctx> { // loop condition let bcast_cond = self.builder.build_int_compare(IntPredicate::ULT, bcast_next_index, bcast_end_index, "broadcast_cond")?; let post_bcast_block = self.context.append_basic_block(self.fn_value(), name); - self.builder.build_conditional_branch(bcast_cond, bcast_block, post_bcast_block); + self.builder.build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?; self.builder.position_at_end(post_bcast_block); // return the current block for later @@ -627,7 +628,7 @@ impl<'ctx> CodeGen<'ctx> { }, }; let resi_ptr = unsafe { self.builder.build_in_bounds_gep(self.tensor_ptr(), &[res_index], name)? }; - self.builder.build_store(resi_ptr, float_value); + self.builder.build_store(resi_ptr, float_value)?; Ok(()) } @@ -749,8 +750,8 @@ impl<'ctx> CodeGen<'ctx> { match arg { BasicValueEnum::PointerValue(v) => v, BasicValueEnum::FloatValue(v) => { - let alloca = self.create_entry_block_builder().build_alloca(arg.get_type(), name)?; - self.builder.build_store(alloca, v); + let alloca = self.create_entry_block_builder().build_alloca(arg.get_type(), name).unwrap(); + self.builder.build_store(alloca, v).unwrap(); alloca } _ => unreachable!() @@ -789,7 +790,7 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_tensor(&model.state(), Some(*self.get_param("u0")))?; self.jit_compile_tensor(&model.state_dot(), Some(*self.get_param("dudt0")))?; - self.builder.build_return(None); + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -829,7 +830,7 @@ impl<'ctx> CodeGen<'ctx> { self.insert_data(model); self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?; - self.builder.build_return(None); + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -882,7 +883,7 @@ impl<'ctx> CodeGen<'ctx> { let res_ptr = self.get_param("rr"); let _res_ptr = self.jit_compile_tensor(&residual, Some(*res_ptr))?; - self.builder.build_return(None); + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -920,12 +921,12 @@ impl<'ctx> CodeGen<'ctx> { let number_of_outputs = model.out().nnz() as u64; let indices_len = self.layout.indices().len() as u64; let data_len = self.layout.data().len() as u64; - self.builder.build_store(*self.get_param("states"), self.int_type.const_int(number_of_states, false)); - self.builder.build_store(*self.get_param("inputs"), self.int_type.const_int(number_of_inputs, false)); - self.builder.build_store(*self.get_param("outputs"), self.int_type.const_int(number_of_outputs, false)); - self.builder.build_store(*self.get_param("indices"), self.int_type.const_int(indices_len, false)); - self.builder.build_store(*self.get_param("data"), self.int_type.const_int(data_len, false)); - self.builder.build_return(None); + self.builder.build_store(*self.get_param("states"), self.int_type.const_int(number_of_states, false))?; + self.builder.build_store(*self.get_param("inputs"), self.int_type.const_int(number_of_inputs, false))?; + self.builder.build_store(*self.get_param("outputs"), self.int_type.const_int(number_of_outputs, false))?; + self.builder.build_store(*self.get_param("indices"), self.int_type.const_int(indices_len, false))?; + self.builder.build_store(*self.get_param("data"), self.int_type.const_int(data_len, false))?; + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -966,9 +967,9 @@ impl<'ctx> CodeGen<'ctx> { let ptr = self.get_param(name); let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64; let tensor_size_value = self.int_type.const_int(tensor_size, false); - self.builder.build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum()); - self.builder.build_store(*self.get_param("tensor_size"), tensor_size_value); - self.builder.build_return(None); + self.builder.build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?; + self.builder.build_store(*self.get_param("tensor_size"), tensor_size_value)?; + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -1014,18 +1015,18 @@ impl<'ctx> CodeGen<'ctx> { let end_index = self.int_type.const_int(input.nnz().try_into().unwrap(), false); let input_block = self.context.append_basic_block(function, name.as_str()); - self.builder.build_unconditional_branch(input_block); + self.builder.build_unconditional_branch(input_block)?; self.builder.position_at_end(input_block); - let index = self.builder.build_phi(self.int_type, "i"); + let index = self.builder.build_phi(self.int_type, "i")?; index.add_incoming(&[(&start_index, block)]); // loop body - copy value from inputs to data let curr_input_index = index.as_basic_value().into_int_value(); - let input_ptr = unsafe { self.builder.build_in_bounds_gep(*ptr, &[curr_input_index], name.as_str()) }; - let curr_inputs_index = self.builder.build_int_add(inputs_start_index, curr_input_index, name.as_str()); - let inputs_ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("inputs"), &[curr_inputs_index], name.as_str()) }; - let input_value = self.builder.build_load(inputs_ptr, name.as_str()).into_float_value(); - self.builder.build_store(input_ptr, input_value); + let input_ptr = unsafe { self.builder.build_in_bounds_gep(*ptr, &[curr_input_index], name.as_str())? }; + let curr_inputs_index = self.builder.build_int_add(inputs_start_index, curr_input_index, name.as_str())?; + let inputs_ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("inputs"), &[curr_inputs_index], name.as_str())? }; + let input_value = self.builder.build_load(inputs_ptr, name.as_str())?.into_float_value(); + self.builder.build_store(input_ptr, input_value)?; // increment loop index let one = self.int_type.const_int(1, false); @@ -1033,16 +1034,16 @@ impl<'ctx> CodeGen<'ctx> { index.add_incoming(&[(&next_index, input_block)]); // loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name.as_str()); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name.as_str())?; let post_block = self.context.append_basic_block(function, name.as_str()); - self.builder.build_conditional_branch(loop_while, input_block, post_block); + self.builder.build_conditional_branch(loop_while, input_block, post_block)?; self.builder.position_at_end(post_block); // get ready for next input block = post_block; inputs_index = inputs_index + input.nnz(); } - self.builder.build_return(None); + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); @@ -1085,18 +1086,18 @@ impl<'ctx> CodeGen<'ctx> { let blk_end_index = self.int_type.const_int(blk.nnz().try_into().unwrap(), false); let blk_block = self.context.append_basic_block(function, name); - self.builder.build_unconditional_branch(blk_block); + self.builder.build_unconditional_branch(blk_block)?; self.builder.position_at_end(blk_block); - let index = self.builder.build_phi(self.int_type, "i"); + let index = self.builder.build_phi(self.int_type, "i")?; index.add_incoming(&[(&blk_start_index, block)]); // loop body - copy value from inputs to data let curr_blk_index = index.as_basic_value().into_int_value(); - let curr_id_index = self.builder.build_int_add(id_start_index, curr_blk_index, name); - let id_ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("id"), &[curr_id_index], name) }; + let curr_id_index = self.builder.build_int_add(id_start_index, curr_blk_index, name)?; + let id_ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("id"), &[curr_id_index], name)? }; let is_algebraic_float = if *is_algebraic { 0.0 as realtype } else { 1.0 as realtype }; let is_algebraic_value = self.real_type.const_float(is_algebraic_float); - self.builder.build_store(id_ptr, is_algebraic_value); + self.builder.build_store(id_ptr, is_algebraic_value)?; // increment loop index let one = self.int_type.const_int(1, false); @@ -1104,16 +1105,16 @@ impl<'ctx> CodeGen<'ctx> { index.add_incoming(&[(&next_index, blk_block)]); // loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, blk_end_index, name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, blk_end_index, name)?; let post_block = self.context.append_basic_block(function, name); - self.builder.build_conditional_branch(loop_while, blk_block, post_block); + self.builder.build_conditional_branch(loop_while, blk_block, post_block)?; self.builder.position_at_end(post_block); // get ready for next blk block = post_block; id_index = id_index + blk.nnz(); } - self.builder.build_return(None); + self.builder.build_return(None)?; if function.verify(true) { self.fpm.run_on(&function); diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index 324628e..8d20851 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -265,9 +265,14 @@ impl Compiler { } pub fn write_bitcode_to_path(&self, path: &Path) -> Result<()> { - self.with_data(|data| - data.codegen.module().write_bitcode_to_path(path).map_err(|e| anyhow::anyhow!("Error writing bitcode: {:?}", e)) - ) + self.with_data(|data| { + let result = data.codegen.module().write_bitcode_to_path(path); + if result { + Ok(()) + } else { + Err(anyhow!("Error writing bitcode to path")) + } + }) } pub fn write_object_file(&self, path: &Path) -> Result<()> { diff --git a/src/lib.rs b/src/lib.rs index 5228b68..550ec66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,7 +99,7 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp let compiler = Compiler::from_discrete_model(&discrete_model)?; - compiler.write_bitcode_to_path(bytecodefile); + compiler.write_bitcode_to_path(bytecodefile)?; if bytecode_only { return Ok(()); @@ -151,7 +151,7 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp } let runtime_path = runtime_path.unwrap(); let mut command = Command::new(command_name); - command.arg("-o").arg(out).arg(objectname.clone()); + command.arg("-o").arg(out).arg(out); for file in linked_files { command.arg(Path::new(runtime_path).join(file)); } @@ -163,7 +163,7 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp command.output() } else { let mut command = Command::new(command_name); - command.arg("-o").arg(out).arg(objectname.clone()); + command.arg("-o").arg(out).arg(out); if standalone { command.arg("-ldiffeq_runtime"); } else { @@ -186,9 +186,6 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp } } - // clean up the object file - std::fs::remove_file(objectfile)?; - Ok(()) }