Skip to content

Commit

Permalink
Add ASM for double evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 28, 2024
1 parent a980382 commit b9b43dc
Showing 1 changed file with 123 additions and 3 deletions.
126 changes: 123 additions & 3 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
res += &format!("\tZ[{}] = {};\n", i, self.stack[i]);
}

Self::export_asm_impl(&self.instructions, &mut res);
Self::export_asm_complex_impl(&self.instructions, &mut res);

for (i, r) in &mut self.result_indices.iter().enumerate() {
res += &format!("\tout[{}] = Z[{}];\n", i, r);
Expand All @@ -989,14 +989,134 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
res += "\treturn;\n}\n";

res += &format!(
"\n\tvoid {}_double(double* params, double* out) {{\n\t\treturn;\n\t}}\n}}\n",
"\n\tvoid {}_double(double* params, double* out) {{\n",
function_name
);

res += &format!("\tdouble Z[{}];\n", self.stack.len());

for i in 0..self.param_count {
res += &format!("\tZ[{}] = params[{}];\n", i, i);
}

for i in self.param_count..self.reserved_indices {
res += &format!("\tZ[{}] = {};\n", i, self.stack[i]);
}

Self::export_asm_double_impl(&self.instructions, &mut res);

for (i, r) in &mut self.result_indices.iter().enumerate() {
res += &format!("\tout[{}] = Z[{}];\n", i, r);
}

res += "\treturn;\n}\n";

res += "}\n";

res
}

fn export_asm_impl(instr: &[Instr], out: &mut String) {
fn export_asm_double_impl(instr: &[Instr], out: &mut String) {
let mut in_asm_block = false;
for ins in instr {
match ins {
Instr::Add(o, a) => {
if !in_asm_block {
*out += "\t__asm__(\n";
in_asm_block = true;
}

*out += &format!("\t\t\"movsd xmm0, QWORD PTR [%0+{}]\\n\\t\"\n", a[0] * 8);

// TODO: try loading in multiple registers for better instruction-level parallelism?
for i in &a[1..] {
*out += &format!("\t\t\"addsd xmm0, QWORD PTR [%0+{}]\\n\\t\"\n", *i * 8);
}
*out += &format!("\t\t\"movsd QWORD PTR [%0+{}], xmm0\\n\\t\"\n", *o * 8,);
}
Instr::Mul(o, a) => {
if !in_asm_block {
*out += "\t__asm__(\n";
in_asm_block = true;
}

*out += &format!("\t\t\"movsd xmm0, QWORD PTR [%0+{}]\\n\\t\"\n", a[0] * 8);

for i in &a[1..] {
*out += &format!("\t\t\"mulsd xmm0, QWORD PTR [%0+{}]\\n\\t\"\n", *i * 8);
}
*out += &format!("\t\t\"movsd QWORD PTR [%0+{}], xmm0\\n\\t\"\n", *o * 8,);
}
Instr::Pow(o, b, e) => {
if in_asm_block {
*out += ":
: \"r\"(Z)
: \"memory\");
";
in_asm_block = false;
}

let base = format!("Z[{}]", b);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str();
}
Instr::Powf(o, b, e) => {
if in_asm_block {
*out += ":
: \"r\"(Z)
: \"memory\");
";
in_asm_block = false;
}

let base = format!("Z[{}]", b);
let exp = format!("Z[{}]", e);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, exp).as_str();
}
Instr::BuiltinFun(o, s, a) => {
if in_asm_block {
*out += ":
: \"r\"(Z)
: \"memory\");
";
in_asm_block = false;
}

match *s {
State::EXP => {
let arg = format!("Z[{}]", a);
*out += format!("\tZ[{}] = exp({});\n", o, arg).as_str();
}
State::LOG => {
let arg = format!("Z[{}]", a);
*out += format!("\tZ[{}] = log({});\n", o, arg).as_str();
}
State::SIN => {
let arg = format!("Z[{}]", a);
*out += format!("\tZ[{}] = sin({});\n", o, arg).as_str();
}
State::COS => {
let arg = format!("Z[{}]", a);
*out += format!("\tZ[{}] = cos({});\n", o, arg).as_str();
}
State::SQRT => {
let arg = format!("Z[{}]", a);
*out += format!("\tZ[{}] = sqrt({});\n", o, arg).as_str();
}
_ => unreachable!(),
}
}
}
}

if in_asm_block {
*out += ":
: \"r\"(Z)
: \"memory\");
";
}
}

fn export_asm_complex_impl(instr: &[Instr], out: &mut String) {
let mut in_asm_block = false;
for ins in instr {
match ins {
Expand Down

0 comments on commit b9b43dc

Please sign in to comment.