Skip to content

Commit

Permalink
add sigmoid function
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Sep 13, 2023
1 parent a1ee78e commit 2221d26
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
50 changes: 34 additions & 16 deletions src/codegen/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,38 +148,56 @@ impl<'ctx> CodeGen<'ctx> {
fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
match self.functions.get(name) {
Some(&func) => Some(func),
// support some llvm intrinsics
None => {
match name {
let function = match name {
// support some llvm intrinsics
"sin" | "cos" | "tan" | "pow" | "exp" | "log" | "sqrt" | "abs" => {
let arg_len = 1;
let llvm_name = format!("llvm.{}.{}", name, self.real_type_str);
let intrinsic = Intrinsic::find(&llvm_name).unwrap();
let ret_type = self.real_type;

let args_types = std::iter::repeat(ret_type)
.take(arg_len)
.map(|f| f.into())
.collect::<Vec<BasicTypeEnum>>();
intrinsic.get_declaration(&self.module, args_types.as_slice())
},
// some custom functions
"sigmoid" => {
let arg_len = 1;
let ret_type = self.real_type;

let args_types = std::iter::repeat(ret_type)
.take(arg_len)
.map(|f| f.into())
.collect::<Vec<BasicMetadataTypeEnum>>();
let args_types = args_types.as_slice();
let fn_type = ret_type.fn_type(args_types, false);

let fn_type = ret_type.fn_type(args_types.as_slice(), false);
let fn_val = self.module.add_function(name, fn_type, None);

for (_, arg) in fn_val.get_param_iter().enumerate() {
arg.into_float_value().set_name("x");
}

let args_types = std::iter::repeat(ret_type)
.take(arg_len)
.map(|f| f.into())
.collect::<Vec<BasicTypeEnum>>();
let args_types = args_types.as_slice();
let function = intrinsic.get_declaration(&self.module, args_types).unwrap();

self.functions.insert(name.to_owned(), function)
let current_block = self.builder.get_insert_block().unwrap();
let basic_block = self.context.append_basic_block(fn_val, "entry");
self.builder.position_at_end(basic_block);
let x = fn_val.get_nth_param(0)?.into_float_value();
let one = self.real_type.const_float(1.0);
let negx = self.builder.build_float_neg(x, name);
let exp = self.get_function("exp").unwrap();
let exp_negx = self.builder.build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name);
let one_plus_exp_negx = self.builder.build_float_add(exp_negx.try_as_basic_value().left().unwrap().into_float_value(), one, name);
let sigmoid = self.builder.build_float_div(one, one_plus_exp_negx, name);
self.builder.build_return(Some(&sigmoid));
self.builder.position_at_end(current_block);
Some(fn_val)
},
_ => None,
}

_ => None
}?;
self.functions.insert(name.to_owned(), function);
Some(function)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ mod tests {
}

tensor_test!{
exp_function: "r { exp(2) }" expect "r" vec![f64::exp(2.0)],
sigmoid_function: "r { sigmoid(0.1) }" expect "r" vec![1.0 / (1.0 + f64::exp(-0.1))],
scalar: "r {2}" expect "r" vec![2.0,],
constant: "r_i {2, 3}" expect "r" vec![2., 3.],
expression: "r_i {2 + 3, 3 * 2}" expect "r" vec![5., 6.],
Expand Down

0 comments on commit 2221d26

Please sign in to comment.