Skip to content

Commit

Permalink
factor our compile function
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 25, 2023
1 parent 99f17d0 commit 3e1048a
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 84 deletions.
83 changes: 7 additions & 76 deletions src/bin/diffeq.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@

use std::{path::Path, ffi::OsStr, process::Command};

use clap::Parser;
use anyhow::{Result, anyhow};
use diffeq::{parser::{parse_ms_string, parse_ds_string}, continuous::ModelInfo, discretise::DiscreteModel, codegen::Compiler};
use anyhow::Result;
use diffeq::compile;

/// compiles a model in continuous (.cs) or discrete (.ds) format to an object file
#[derive(Parser, Debug)]
Expand All @@ -23,80 +20,14 @@ struct Args {
/// Compile object file only
#[arg(short, long)]
compile: bool,

/// Compile to WASM
#[arg(short, long)]
wasm: bool,
}

fn main() -> Result<()> {
let cli = Args::parse();

let inputfile = Path::new(&cli.input);
let out = if let Some(out) = cli.out {
out.clone()
} else if cli.compile {
"out.o".to_owned()
} else {
"out".to_owned()
};

let objectname = if cli.compile { out.clone() } else { format!("{}.o", out) };
let objectfile = Path::new(objectname.as_str());
let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds";
let is_continuous = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "cs";
if !is_discrete && !is_continuous {
panic!("Input file must have extension .ds or .cs");
}
let text = std::fs::read_to_string(inputfile)?;
if is_continuous {
let models = parse_ms_string(text.as_str())?;
let model_name = if let Some(model_name) = cli.model {
model_name
} else {
return Err(anyhow!("Model name must be specified for continuous models"));
};
let model_info = ModelInfo::build(model_name.as_str(), &models).map_err(|e| anyhow!("{}", e))?;
if model_info.errors.len() > 0 {
for error in model_info.errors {
println!("{}", error.as_error_message(text.as_str()));
}
return Err(anyhow!("Errors in model"));
}
let model = DiscreteModel::from(&model_info);
let compiler = Compiler::from_discrete_model(&model)?;
compiler.write_object_file(objectfile)?;
} else {
let model = parse_ds_string(text.as_str())?;
let model = match DiscreteModel::build(&cli.input, &model) {
Ok(model) => model,
Err(e) => {
println!("{}", e.as_error_message(text.as_str()));
return Err(anyhow!("Errors in model"));
}
};
let compiler = Compiler::from_discrete_model(&model)?;
compiler.write_object_file(objectfile)?;
};

if cli.compile {
return Ok(());
}

// compile the object file using clang and our runtime library
let output = Command::new("clang")
.arg("-o")
.arg(out)
.arg(objectname.clone())
.arg("-ldiffeq_runtime")
.output()?;

// clean up the object file
std::fs::remove_file(objectfile)?;

if let Some(code) = output.status.code() {
if code != 0 {
println!("{}", String::from_utf8_lossy(&output.stderr));
return Err(anyhow!("clang returned error code {}", code));
}
}
Ok(())

compile(&cli.input, cli.out.as_deref(), cli.model.as_deref(), cli.compile, cli.wasm)
}

4 changes: 2 additions & 2 deletions src/codegen/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,10 +792,10 @@ impl<'ctx> CodeGen<'ctx> {
let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default());
let void_type = self.context.void_type();
let fn_type = void_type.fn_type(
&[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), int_ptr_type.into(), real_ptr_type.into()]
&[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), int_ptr_type.into()]
, false
);
let fn_arg_names = &["t", "u", "dudt", "data", "indices", "out"];
let fn_arg_names = &["t", "u", "dudt", "data", "indices"];
let function = self.module.add_function("calc_out", fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.fn_value_opt = Some(function);
Expand Down
43 changes: 37 additions & 6 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::path::Path;
use anyhow::anyhow;

use anyhow::Result;
use inkwell::targets::TargetMachine;
use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, InitializationConfig, Target, RelocMode, CodeModel, FileType}, execution_engine::{JitFunction, ExecutionEngine, UnsafeFunctionPointer}};
use ouroboros::self_referencing;
use crate::discretise::DiscreteModel;
Expand Down Expand Up @@ -225,23 +226,53 @@ impl Compiler {
})
}

pub fn write_object_file(&self, path: &Path) -> Result<()> {
Target::initialize_x86(&InitializationConfig::default());
fn get_native_machine() -> Result<TargetMachine> {
Target::initialize_native(&InitializationConfig::default()).map_err(|e| anyhow!("{}", e))?;
let opt = OptimizationLevel::Default;
let reloc = RelocMode::Default;
let model = CodeModel::Default;
let target_triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&target_triple).unwrap();
let target_machine = target.create_target_machine(
&target_triple,
TargetMachine::get_host_cpu_name().to_str().unwrap(),
TargetMachine::get_host_cpu_features().to_str().unwrap(),
opt,
reloc,
model
)
.unwrap();
Ok(target_machine)
}

fn get_wasm_machine() -> Result<TargetMachine> {
Target::initialize_webassembly(&InitializationConfig::default());
let opt = OptimizationLevel::Default;
let reloc = RelocMode::Default;
let model = CodeModel::Default;
let target = Target::from_name("x86-64").unwrap();
let target_triple = TargetTriple::create("wasm32-unknown-emscripten");
let target = Target::from_triple(&target_triple).unwrap();
let target_machine = target.create_target_machine(
&TargetTriple::create("x86_64-pc-linux-gnu"),
"x86-64",
"+avx2",
&target_triple,
"generic",
"",
opt,
reloc,
model
)
.unwrap();
Ok(target_machine)
}

pub fn write_object_file(&self, path: &Path) -> Result<()> {
let target_machine = Compiler::get_native_machine()?;
self.with_data(|data|
target_machine.write_to_file(data.codegen.module(), FileType::Object, &path).map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e))
)
}

pub fn write_wasm_object_file(&self, path: &Path) -> Result<()> {
let target_machine = Compiler::get_wasm_machine()?;
self.with_data(|data|
target_machine.write_to_file(data.codegen.module(), FileType::Object, &path).map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e))
)
Expand Down
109 changes: 109 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
use std::{path::Path, ffi::OsStr, process::Command};
use anyhow::{Result, anyhow};
use codegen::Compiler;
use continuous::ModelInfo;
use discretise::DiscreteModel;
use parser::{parse_ms_string, parse_ds_string};

extern crate pest;
#[macro_use]
extern crate pest_derive;
Expand All @@ -9,5 +16,107 @@ pub mod continuous;
pub mod codegen;


pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, compile: bool, wasm: bool) -> Result<()> {
let inputfile = Path::new(input);
let out = if let Some(out) = out {
out.clone()
} else if compile {
"out.o"
} else {
"out"
};

let objectname = if compile { out.to_owned() } else { format!("{}.o", out) };
let objectfile = Path::new(objectname.as_str());
let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds";
let is_continuous = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "cs";
if !is_discrete && !is_continuous {
panic!("Input file must have extension .ds or .cs");
}

let model_name = if is_continuous {
if let Some(model_name) = model {
model_name
} else {
return Err(anyhow!("Model name must be specified for continuous models"));
}
} else {
""
};

let text = std::fs::read_to_string(inputfile)?;
let continuous_ast = if is_continuous {
Some(parse_ms_string(text.as_str())?)
} else {
None
};

let discrete_ast = if is_discrete {
Some(parse_ds_string(text.as_str())?)
} else {
None
};

let continuous_model_info = if let Some(ast) = &continuous_ast {
let model_info = ModelInfo::build(model_name, ast).map_err(|e| anyhow!("{}", e))?;
if model_info.errors.len() > 0 {
for error in model_info.errors {
println!("{}", error.as_error_message(text.as_str()));
}
return Err(anyhow!("Errors in model"));
}
Some(model_info)
} else {
None
};

let discrete_model = if let Some(model_info) = &continuous_model_info {
let model = DiscreteModel::from(&model_info);
model
} else if let Some(ast) = &discrete_ast {
match DiscreteModel::build(input, ast) {
Ok(model) => model,
Err(e) => {
println!("{}", e.as_error_message(text.as_str()));
return Err(anyhow!("Errors in model"));
}
}
} else {
panic!("No model found");
};

let compiler = Compiler::from_discrete_model(&discrete_model)?;

if wasm {
compiler.write_wasm_object_file(objectfile)?;
} else {
compiler.write_object_file(objectfile)?;
}

if compile {
return Ok(());
}

// compile the object file using clang and our runtime library
let output = Command::new("clang")
.arg("-o")
.arg(out)
.arg(objectname.clone())
.arg("-ldiffeq_runtime")
.output()?;

// clean up the object file
std::fs::remove_file(objectfile)?;

if let Some(code) = output.status.code() {
if code != 0 {
println!("{}", String::from_utf8_lossy(&output.stderr));
return Err(anyhow!("clang returned error code {}", code));
}
}
Ok(())
}




0 comments on commit 3e1048a

Please sign in to comment.