diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 219dc84..3563254 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,23 +25,29 @@ jobs: unit-tests: name: Tests - ${{ matrix.os }} - ${{ matrix.toolchain }} - ${{ matrix.llvm }} - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: llvm: - "14" - "15" + - "16" + - "17" toolchain: - stable os: - ubuntu-latest - - macos-latest - - windows-latest include: - toolchain: beta os: ubuntu-latest + llvm: "14" - toolchain: nightly os: ubuntu-latest + llvm: "14" + - toolchain: stable + os: macos-13 + llvm: "14" + steps: - uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index 2342b31..f5058aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,15 +14,6 @@ name = "diffsl" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -llvm4-0 = ["inkwell-40", "llvm-sys-40"] -llvm5-0 = ["inkwell-50", "llvm-sys-50"] -llvm6-0 = ["inkwell-60", "llvm-sys-60"] -llvm7-0 = ["inkwell-70", "llvm-sys-70"] -llvm8-0 = ["inkwell-80", "llvm-sys-80"] -llvm9-0 = ["inkwell-90", "llvm-sys-90"] -llvm10-0 = ["inkwell-100", "llvm-sys-100"] -llvm11-0 = ["inkwell-110", "llvm-sys-110"] -llvm12-0 = ["inkwell-120", "llvm-sys-120"] llvm13-0 = ["inkwell-130", "llvm-sys-130"] llvm14-0 = ["inkwell-140", "llvm-sys-140"] llvm15-0 = ["inkwell-150", "llvm-sys-150"] @@ -39,34 +30,17 @@ itertools = ">=0.10.3" ouroboros = ">=0.17" clap = { version = "4.3.23", features = ["derive"] } uid = "0.1.7" -inkwell-40 = { package = "inkwell", version = ">=0.4.0", features = ["llvm4-0"], optional = true } -inkwell-50 = { package = "inkwell", version = ">=0.4.0", features = ["llvm5-0"], optional = true } -inkwell-60 = { package = "inkwell", version = ">=0.4.0", features = ["llvm6-0"], optional = true } -inkwell-70 = { package = "inkwell", version = ">=0.4.0", features = ["llvm7-0"], optional = true } -inkwell-80 = { package = "inkwell", version = ">=0.4.0", features = ["llvm8-0"], optional = true } -inkwell-90 = { package = "inkwell", version = ">=0.4.0", features = ["llvm9-0"], optional = true } -inkwell-100 = { package = "inkwell", version = ">=0.4.0", features = ["llvm10-0"], optional = true } -inkwell-110 = { package = "inkwell", version = ">=0.4.0", features = ["llvm11-0"], optional = true } -inkwell-120 = { package = "inkwell", version = ">=0.4.0", features = ["llvm12-0"], optional = true } inkwell-130 = { package = "inkwell", version = ">=0.4.0", features = ["llvm13-0"], optional = true } inkwell-140 = { package = "inkwell", version = ">=0.4.0", features = ["llvm14-0"], optional = true } inkwell-150 = { package = "inkwell", version = ">=0.4.0", features = ["llvm15-0"], optional = true } inkwell-160 = { package = "inkwell", version = ">=0.4.0", features = ["llvm16-0"], optional = true } inkwell-170 = { package = "inkwell", version = ">=0.4.0", features = ["llvm17-0"], optional = true } -llvm-sys-40 = { package = "llvm-sys", version = "40.4", optional = true } -llvm-sys-50 = { package = "llvm-sys", version = "50.4", optional = true } -llvm-sys-60 = { package = "llvm-sys", version = "60.6", optional = true } -llvm-sys-70 = { package = "llvm-sys", version = "70.4", optional = true } -llvm-sys-80 = { package = "llvm-sys", version = "80.3", optional = true } -llvm-sys-90 = { package = "llvm-sys", version = "90.2.1", optional = true } -llvm-sys-100 = { package = "llvm-sys", version = "100.2.3", optional = true } -llvm-sys-110 = { package = "llvm-sys", version = "110.0.3", optional = true } -llvm-sys-120 = { package = "llvm-sys", version = "120.2.4", optional = true } llvm-sys-130 = { package = "llvm-sys", version = "130.0.4", optional = true } llvm-sys-140 = { package = "llvm-sys", version = "140.0.2", optional = true } llvm-sys-150 = { package = "llvm-sys", version = "150.0.3", optional = true } llvm-sys-160 = { package = "llvm-sys", version = "160.1.0", optional = true } llvm-sys-170 = { package = "llvm-sys", version = "170.0.1", optional = true } +inkwell_internals = "0.9.0" [build-dependencies] cmake = "0.1.50" diff --git a/README.md b/README.md index a109812..a7e3693 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,6 @@ You can install DiffSL using cargo. You will need to indicate the llvm version y cargo add diffsl --features llvm14-0 ``` -Other versions of llvm are also supported given by the features `llvm4-0`, `llvm5-0`, `llvm6-0`, `llvm7-0`, `llvm8-0`, `llvm9-0`, `llvm10-0`, `llvm11-0`, `llvm12-0`, `llvm13-0`, `llvm14-0`, `llvm15-0`, `llvm16-0`, `llvm17-0`. +Other versions of llvm are also supported given by the features `llvm13-0`, `llvm14-0`, `llvm15-0`, `llvm16-0`, `llvm17-0`. diff --git a/src/execution/codegen.rs b/src/execution/codegen.rs index 7f4aa95..cb3fb78 100644 --- a/src/execution/codegen.rs +++ b/src/execution/codegen.rs @@ -5,13 +5,13 @@ use inkwell::builder::Builder; use inkwell::context::AsContextRef; use inkwell::intrinsics::Intrinsic; use inkwell::module::Module; -use inkwell::passes::PassManager; -use inkwell::types::{AnyTypeEnum, BasicMetadataTypeEnum, BasicTypeEnum, FloatType, IntType}; +use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, IntType}; use inkwell::values::{ AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FloatValue, FunctionValue, GlobalValue, IntValue, PointerValue, }; use inkwell::{AddressSpace, FloatPredicate, IntPredicate}; +use inkwell_internals::llvm_versions; use llvm_sys::prelude::LLVMValueRef; use std::collections::HashMap; use std::iter::zip; @@ -99,7 +99,7 @@ pub type GetOutFunc = unsafe extern "C" fn( ); struct Globals<'ctx> { - indices: GlobalValue<'ctx>, + indices: Option>, } impl<'ctx> Globals<'ctx> { @@ -107,7 +107,10 @@ impl<'ctx> Globals<'ctx> { layout: &DataLayout, context: &'ctx inkwell::context::Context, module: &Module<'ctx>, - ) -> Result { + ) -> Self { + if layout.indices().is_empty() { + return Self { indices: None }; + } let int_type = context.i32_type(); let indices_array_type = int_type.array_type(u32::try_from(layout.indices().len()).unwrap()); @@ -119,14 +122,14 @@ impl<'ctx> Globals<'ctx> { let indices_value = int_type.const_array(indices_array_values.as_slice()); let _int_ptr_type = int_type.ptr_type(AddressSpace::default()); let globals = Self { - indices: module.add_global( + indices: Some(module.add_global( indices_array_type, Some(AddressSpace::default()), "indices", - ), + )), }; - globals.indices.set_initializer(&indices_value); - Ok(globals) + globals.indices.unwrap().set_initializer(&indices_value); + globals } } @@ -140,7 +143,6 @@ pub struct CodeGen<'ctx> { context: &'ctx inkwell::context::Context, module: Module<'ctx>, builder: Builder<'ctx>, - fpm: PassManager>, variables: HashMap>, functions: HashMap>, fn_value_opt: Option>, @@ -149,7 +151,7 @@ pub struct CodeGen<'ctx> { real_type_str: String, int_type: IntType<'ctx>, layout: DataLayout, - globals: Option>, + globals: Globals<'ctx>, } impl<'ctx> CodeGen<'ctx> { @@ -160,24 +162,13 @@ impl<'ctx> CodeGen<'ctx> { real_type: FloatType<'ctx>, real_type_str: &str, ) -> Self { - let fpm = PassManager::create(&module); - fpm.add_instruction_combining_pass(); - fpm.add_reassociate_pass(); - fpm.add_gvn_pass(); - fpm.add_cfg_simplification_pass(); - fpm.add_basic_alias_analysis_pass(); - fpm.add_promote_memory_to_register_pass(); - fpm.add_instruction_combining_pass(); - fpm.add_reassociate_pass(); - fpm.initialize(); let builder = context.create_builder(); let layout = DataLayout::new(model); - let globals = Globals::new(&layout, context, &module).ok(); + let globals = Globals::new(&layout, context, &module); Self { context, module, builder, - fpm, real_type, real_type_str: real_type_str.to_owned(), variables: HashMap::new(), @@ -214,21 +205,113 @@ impl<'ctx> CodeGen<'ctx> { self.insert_tensor(model.rhs()); } + #[llvm_versions(4.0..=14.0)] fn insert_indices(&mut self) { - let indices = self.globals.as_ref().unwrap().indices; - let zero = self.context.i32_type().const_int(0, false); - let ptr = unsafe { - indices - .as_pointer_value() - .const_in_bounds_gep(&[zero, zero]) - }; - self.variables.insert("indices".to_owned(), ptr); + if let Some(indices) = self.globals.indices.as_ref() { + let zero = self.context.i32_type().const_int(0, false); + let ptr = unsafe { + indices + .as_pointer_value() + .const_in_bounds_gep(&[zero, zero]) + }; + self.variables.insert("indices".to_owned(), ptr); + } + } + + #[llvm_versions(15.0..=latest)] + fn insert_indices(&mut self) { + if let Some(indices) = self.globals.indices.as_ref() { + let i32_type = self.context.i32_type(); + let zero = i32_type.const_int(0, false); + let ptr = unsafe { + indices + .as_pointer_value() + .const_in_bounds_gep(i32_type, &[zero]) + }; + self.variables.insert("indices".to_owned(), ptr); + } } fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) { self.variables.insert(name.to_owned(), value); } + #[llvm_versions(4.0..=14.0)] + fn build_gep>( + &self, + _ty: T, + ptr: PointerValue<'ctx>, + ordered_indexes: &[IntValue<'ctx>], + name: &str, + ) -> Result> { + unsafe { + self.builder + .build_gep(ptr, ordered_indexes, name) + .map_err(|e| e.into()) + } + } + + #[llvm_versions(15.0..=latest)] + fn build_gep>( + &self, + ty: T, + ptr: PointerValue<'ctx>, + ordered_indexes: &[IntValue<'ctx>], + name: &str, + ) -> Result> { + unsafe { + self.builder + .build_gep(ty, ptr, ordered_indexes, name) + .map_err(|e| e.into()) + } + } + + #[llvm_versions(4.0..=14.0)] + fn build_load>( + &self, + _ty: T, + ptr: PointerValue<'ctx>, + name: &str, + ) -> Result> { + self.builder.build_load(ptr, name).map_err(|e| e.into()) + } + + #[llvm_versions(15.0..=latest)] + fn build_load>( + &self, + ty: T, + ptr: PointerValue<'ctx>, + name: &str, + ) -> Result> { + self.builder.build_load(ty, ptr, name).map_err(|e| e.into()) + } + + #[llvm_versions(4.0..=14.0)] + fn get_ptr_to_index>( + builder: &Builder<'ctx>, + _ty: T, + ptr: &PointerValue<'ctx>, + index: IntValue<'ctx>, + name: &str, + ) -> PointerValue<'ctx> { + unsafe { builder.build_in_bounds_gep(*ptr, &[index], name).unwrap() } + } + + #[llvm_versions(15.0..=latest)] + fn get_ptr_to_index>( + builder: &Builder<'ctx>, + ty: T, + ptr: &PointerValue<'ctx>, + index: IntValue<'ctx>, + name: &str, + ) -> PointerValue<'ctx> { + unsafe { + builder + .build_in_bounds_gep(ty, *ptr, &[index], name) + .unwrap() + } + } + fn insert_state(&mut self, u: &Tensor) { let mut data_index = 0; for blk in u.elmts() { @@ -238,11 +321,13 @@ impl<'ctx> CodeGen<'ctx> { .context .i32_type() .const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { - self.create_entry_block_builder() - .build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()) - .unwrap() - }; + let alloca = Self::get_ptr_to_index( + &self.create_entry_block_builder(), + self.real_type, + ptr, + i, + blk.name().unwrap(), + ); self.variables.insert(name.to_owned(), alloca); } data_index += blk.nnz(); @@ -257,11 +342,13 @@ impl<'ctx> CodeGen<'ctx> { .context .i32_type() .const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { - self.create_entry_block_builder() - .build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()) - .unwrap() - }; + let alloca = Self::get_ptr_to_index( + &self.create_entry_block_builder(), + self.real_type, + ptr, + i, + blk.name().unwrap(), + ); self.variables.insert(name.to_owned(), alloca); } data_index += blk.nnz(); @@ -274,11 +361,13 @@ impl<'ctx> CodeGen<'ctx> { .context .i32_type() .const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { - self.create_entry_block_builder() - .build_in_bounds_gep(ptr, &[i], tensor.name()) - .unwrap() - }; + let alloca = Self::get_ptr_to_index( + &self.create_entry_block_builder(), + self.real_type, + &ptr, + i, + tensor.name(), + ); self.variables.insert(tensor.name().to_owned(), alloca); //insert any named blocks @@ -288,11 +377,13 @@ impl<'ctx> CodeGen<'ctx> { .context .i32_type() .const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { - self.create_entry_block_builder() - .build_in_bounds_gep(ptr, &[i], name) - .unwrap() - }; + let alloca = Self::get_ptr_to_index( + &self.create_entry_block_builder(), + self.real_type, + &ptr, + i, + name, + ); self.variables.insert(name.to_owned(), alloca); } // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index @@ -734,12 +825,10 @@ impl<'ctx> CodeGen<'ctx> { // load and increment the expression index let expr_index = self - .builder - .build_load(expr_index_ptr, "expr_index")? + .build_load(self.int_type, expr_index_ptr, "expr_index")? .into_int_value(); let elmt_index = self - .builder - .build_load(elmt_index_ptr, "elmt_index")? + .build_load(self.int_type, elmt_index_ptr, "elmt_index")? .into_int_value(); let next_expr_index = self .builder @@ -755,8 +844,7 @@ impl<'ctx> CodeGen<'ctx> { if contract_sum.is_some() { let contract_sum_value = self - .builder - .build_load(contract_sum.unwrap(), "contract_sum")? + .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")? .into_float_value(); let new_contract_sum_value = self.builder.build_float_add( contract_sum_value, @@ -788,8 +876,7 @@ impl<'ctx> CodeGen<'ctx> { if i == expr_rank - contract_by - 1 && contract_sum.is_some() { let contract_sum_value = self - .builder - .build_load(contract_sum.unwrap(), "contract_sum")? + .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")? .into_float_value(); let next_elmt_index = self.builder @@ -860,22 +947,24 @@ impl<'ctx> CodeGen<'ctx> { let end_index = self.builder .build_int_add(start_index, int_type.const_int(1, false), name)?; - let start_ptr = unsafe { - self.builder.build_gep( - *self.get_param("indices"), - &[start_index], - "start_index_ptr", - )? - }; + let start_ptr = self.build_gep( + self.int_type, + *self.get_param("indices"), + &[start_index], + "start_index_ptr", + )?; let start_contract = self - .builder - .build_load(start_ptr, "start")? + .build_load(self.int_type, start_ptr, "start")? + .into_int_value(); + let end_ptr = self.build_gep( + self.int_type, + *self.get_param("indices"), + &[end_index], + "end_index_ptr", + )?; + let end_contract = self + .build_load(self.int_type, end_ptr, "end")? .into_int_value(); - let end_ptr = unsafe { - self.builder - .build_gep(*self.get_param("indices"), &[end_index], "end_index_ptr")? - }; - let end_contract = self.builder.build_load(end_ptr, "end")?.into_int_value(); // initialise the contract sum self.builder @@ -907,14 +996,14 @@ impl<'ctx> CodeGen<'ctx> { layout_index_plus_offset, name, )?; - let ptr = unsafe { - self.builder.build_in_bounds_gep( - *self.get_param("indices"), - &[curr_index], - name, - )? - }; - let index = self.builder.build_load(ptr, name)?.into_int_value(); + let ptr = Self::get_ptr_to_index( + &self.builder, + self.int_type, + self.get_param("indices"), + curr_index, + name, + ); + let index = self.build_load(self.int_type, ptr, name)?.into_int_value(); Ok(index) }) .collect::, anyhow::Error>>()?; @@ -928,8 +1017,7 @@ impl<'ctx> CodeGen<'ctx> { Some(expr_index), )?; let contract_sum_value = self - .builder - .build_load(contract_sum_ptr, "contract_sum")? + .build_load(self.real_type, contract_sum_ptr, "contract_sum")? .into_float_value(); let new_contract_sum_value = self.builder @@ -1025,14 +1113,14 @@ impl<'ctx> CodeGen<'ctx> { layout_index_plus_offset, name, )?; - let ptr = unsafe { - self.builder.build_in_bounds_gep( - *self.get_param("indices"), - &[curr_index], - name, - )? - }; - Ok(self.builder.build_load(ptr, name)?.into_int_value()) + let ptr = Self::get_ptr_to_index( + &self.builder, + self.int_type, + self.get_param("indices"), + curr_index, + name, + ); + Ok(self.build_load(self.int_type, ptr, name)?.into_int_value()) }) .collect::, anyhow::Error>>()?; @@ -1231,20 +1319,24 @@ impl<'ctx> CodeGen<'ctx> { let curr_index = self.builder .build_int_add(elmt_index_strided, translate_store_index, name)?; - let ptr = unsafe { - self.builder.build_in_bounds_gep( - *self.get_param("indices"), - &[curr_index], - name, - )? - }; - self.builder.build_load(ptr, name)?.into_int_value() + let ptr = Self::get_ptr_to_index( + &self.builder, + self.int_type, + self.get_param("indices"), + curr_index, + name, + ); + self.build_load(self.int_type, ptr, name)?.into_int_value() } }; - let resi_ptr = unsafe { - self.builder - .build_in_bounds_gep(self.tensor_ptr(), &[res_index], name)? - }; + + let resi_ptr = Self::get_ptr_to_index( + &self.builder, + self.real_type, + &self.tensor_ptr(), + res_index, + name, + ); self.builder.build_store(resi_ptr, float_value)?; Ok(()) } @@ -1350,24 +1442,27 @@ impl<'ctx> CodeGen<'ctx> { panic!("unexpected layout"); }; let value_ptr = match iname_elmt_index { - Some(index) => unsafe { - self.builder.build_in_bounds_gep(*ptr, &[index], name)? - }, + Some(index) => { + Self::get_ptr_to_index(&self.builder, self.real_type, ptr, index, name) + } None => *ptr, }; - Ok(self.builder.build_load(value_ptr, name)?.into_float_value()) + Ok(self + .build_load(self.real_type, value_ptr, name)? + .into_float_value()) } AstKind::Name(name) => { // must be a scalar, just load the value let ptr = self.get_param(name); - Ok(self.builder.build_load(*ptr, name)?.into_float_value()) + Ok(self + .build_load(self.real_type, *ptr, name)? + .into_float_value()) } AstKind::NamedGradient(name) => { let name_str = name.to_string(); let ptr = self.get_param(name_str.as_str()); Ok(self - .builder - .build_load(*ptr, name_str.as_str())? + .build_load(self.real_type, *ptr, name_str.as_str())? .into_float_value()) } AstKind::Index(_) => todo!(), @@ -1436,8 +1531,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); - Ok(function) } else { function.print_to_stderr(); @@ -1493,8 +1586,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); - Ok(function) } else { function.print_to_stderr(); @@ -1552,8 +1643,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); - Ok(function) } else { function.print_to_stderr(); @@ -1618,7 +1707,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -1688,7 +1776,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -1754,13 +1841,13 @@ impl<'ctx> CodeGen<'ctx> { let mut enzyme_fn_args: Vec = Vec::new(); let mut input_activity = Vec::new(); let mut arg_trees = Vec::new(); - for (i, _arg) in original_function.get_param_iter().enumerate() { + for (i, arg) in original_function.get_param_iter().enumerate() { let param_index = start_param_index[i]; let fn_arg = function.get_nth_param(param_index).unwrap(); // we'll probably only get double or pointers to doubles, so let assume this for now // todo: perhaps refactor this into a recursive function, might be overkill - let concrete_type = match _arg.get_type() { + let concrete_type = match arg.get_type() { BasicTypeEnum::PointerType(_) => CConcreteType_DT_Pointer, BasicTypeEnum::FloatType(t) => { if t == self.context.f64_type() { @@ -1781,17 +1868,8 @@ impl<'ctx> CodeGen<'ctx> { // pointer to double if concrete_type == CConcreteType_DT_Pointer { - let inner_concrete_type = - match _arg.get_type().into_pointer_type().get_element_type() { - AnyTypeEnum::FloatType(t) => { - if t == self.context.f64_type() { - CConcreteType_DT_Double - } else { - panic!("unsupported type") - } - } - _ => panic!("unsupported type"), - }; + // assume the pointer is to a double + let inner_concrete_type = CConcreteType_DT_Double; let inner_new_tree = unsafe { EnzymeNewTypeTreeCT( inner_concrete_type, @@ -1904,7 +1982,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -1975,7 +2052,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -2033,7 +2109,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -2082,23 +2157,25 @@ impl<'ctx> CodeGen<'ctx> { // loop body - copy value from inputs to data let curr_input_index = index.as_basic_value().into_int_value(); - let input_ptr = unsafe { - self.builder - .build_in_bounds_gep(*ptr, &[curr_input_index], name.as_str())? - }; + let input_ptr = Self::get_ptr_to_index( + &self.builder, + self.real_type, + ptr, + curr_input_index, + name.as_str(), + ); let curr_inputs_index = self.builder .build_int_add(inputs_start_index, curr_input_index, name.as_str())?; - let inputs_ptr = unsafe { - self.builder.build_in_bounds_gep( - *self.get_param("inputs"), - &[curr_inputs_index], - name.as_str(), - )? - }; + let inputs_ptr = Self::get_ptr_to_index( + &self.builder, + self.real_type, + self.get_param("inputs"), + curr_inputs_index, + name.as_str(), + ); let input_value = self - .builder - .build_load(inputs_ptr, name.as_str())? + .build_load(self.real_type, inputs_ptr, name.as_str())? .into_float_value(); self.builder.build_store(input_ptr, input_value)?; @@ -2128,7 +2205,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); @@ -2177,10 +2253,13 @@ impl<'ctx> CodeGen<'ctx> { let curr_id_index = self .builder .build_int_add(id_start_index, curr_blk_index, name)?; - let id_ptr = unsafe { - self.builder - .build_in_bounds_gep(*self.get_param("id"), &[curr_id_index], name)? - }; + let id_ptr = Self::get_ptr_to_index( + &self.builder, + self.real_type, + self.get_param("id"), + curr_id_index, + name, + ); let is_algebraic_float = if *is_algebraic { 0.0 as RealType } else { @@ -2213,7 +2292,6 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_return(None)?; if function.verify(true) { - self.fpm.run_on(&function); Ok(function) } else { function.print_to_stderr(); diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index bbc69f4..1bea08f 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -1,6 +1,8 @@ use anyhow::anyhow; -use inkwell::passes::PassManager; -use inkwell::passes::PassManagerBuilder; +use inkwell::{ + passes::PassBuilderOptions, + targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine}, +}; use std::env; use std::path::Path; use uid::Id; @@ -10,11 +12,10 @@ use crate::parser::parse_ds_string; use crate::utils::find_executable; use crate::utils::find_runtime_path; use anyhow::Result; -use inkwell::targets::TargetMachine; use inkwell::{ context::Context, execution_engine::{ExecutionEngine, JitFunction, UnsafeFunctionPointer}, - targets::{CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetTriple}, + targets::{FileType, TargetTriple}, OptimizationLevel, }; use ouroboros::self_referencing; @@ -159,12 +160,38 @@ impl Compiler { let _get_output = codegen.compile_get_tensor(model, "out")?; // optimise at -O2 no unrolling before giving to enzyme - let builder = PassManagerBuilder::create(); - builder.set_optimization_level(OptimizationLevel::Default); - builder.set_disable_unroll_loops(true); - let pass_manager = PassManager::create(()); - builder.populate_module_pass_manager(&pass_manager); - pass_manager.run_on(codegen.module()); + let pass_options = PassBuilderOptions::create(); + //pass_options.set_verify_each(true); + //pass_options.set_debug_logging(true); + //pass_options.set_loop_interleaving(true); + pass_options.set_loop_vectorization(false); + pass_options.set_loop_slp_vectorization(false); + pass_options.set_loop_unrolling(false); + //pass_options.set_forget_all_scev_in_loop_unroll(true); + //pass_options.set_licm_mssa_opt_cap(1); + //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10); + //pass_options.set_call_graph_profile(true); + //pass_options.set_merge_functions(true); + + let initialization_config = &InitializationConfig::default(); + Target::initialize_all(initialization_config); + let triple = TargetMachine::get_default_triple(); + let target = Target::from_triple(&triple).unwrap(); + let machine = target + .create_target_machine( + &triple, + "generic", //TargetMachine::get_host_cpu_name().to_string().as_str(), + "", //TargetMachine::get_host_cpu_features().to_string().as_str(), + inkwell::OptimizationLevel::Default, + inkwell::targets::RelocMode::Default, + inkwell::targets::CodeModel::Default, + ) + .unwrap(); + + codegen + .module() + .run_passes("default", &machine, pass_options) + .unwrap(); let _rhs_grad = codegen.compile_gradient( _rhs, diff --git a/src/lib.rs b/src/lib.rs index bd39e6a..9eb6923 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,51 +17,27 @@ pub mod execution; pub mod parser; pub mod utils; -#[cfg(feature = "inkwell-100")] -pub extern crate inkwell_100 as inkwell; -#[cfg(feature = "inkwell-110")] -pub extern crate inkwell_110 as inkwell; -#[cfg(feature = "inkwell-120")] -pub extern crate inkwell_120 as inkwell; #[cfg(feature = "inkwell-130")] -pub extern crate inkwell_130 as inkwell; +extern crate inkwell_130 as inkwell; #[cfg(feature = "inkwell-140")] -pub extern crate inkwell_140 as inkwell; +extern crate inkwell_140 as inkwell; #[cfg(feature = "inkwell-150")] -pub extern crate inkwell_150 as inkwell; +extern crate inkwell_150 as inkwell; #[cfg(feature = "inkwell-160")] -pub extern crate inkwell_160 as inkwell; +extern crate inkwell_160 as inkwell; #[cfg(feature = "inkwell-170")] -pub extern crate inkwell_170 as inkwell; -#[cfg(feature = "inkwell-40")] -pub extern crate inkwell_40 as inkwell; -#[cfg(feature = "inkwell-50")] -pub extern crate inkwell_50 as inkwell; -#[cfg(feature = "inkwell-60")] -pub extern crate inkwell_60 as inkwell; -#[cfg(feature = "inkwell-70")] -pub extern crate inkwell_70 as inkwell; -#[cfg(feature = "inkwell-80")] -pub extern crate inkwell_80 as inkwell; -#[cfg(feature = "inkwell-90")] -pub extern crate inkwell_90 as inkwell; - -#[cfg(feature = "inkwell-100")] -pub extern crate llvm_sys_100 as llvm_sys; -#[cfg(feature = "inkwell-110")] -pub extern crate llvm_sys_110 as llvm_sys; -#[cfg(feature = "inkwell-120")] -pub extern crate llvm_sys_120 as llvm_sys; +extern crate inkwell_170 as inkwell; + #[cfg(feature = "inkwell-130")] -pub extern crate llvm_sys_130 as llvm_sys; +extern crate llvm_sys_130 as llvm_sys; #[cfg(feature = "inkwell-140")] -pub extern crate llvm_sys_140 as llvm_sys; +extern crate llvm_sys_140 as llvm_sys; #[cfg(feature = "inkwell-150")] -pub extern crate llvm_sys_150 as llvm_sys; +extern crate llvm_sys_150 as llvm_sys; #[cfg(feature = "inkwell-160")] -pub extern crate llvm_sys_160 as llvm_sys; +extern crate llvm_sys_160 as llvm_sys; #[cfg(feature = "inkwell-170")] -pub extern crate llvm_sys_170 as llvm_sys; +extern crate llvm_sys_170 as llvm_sys; pub struct CompilerOptions { pub bitcode_only: bool,