diff --git a/README.md b/README.md index 0c87783..0e51ce7 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ fn main () -> Result<(), Error> { // use braces to refer to previously defined lambda let mut and_f_t = lambda!({and} {ff} {tt}); - and_f_t.simplify()?; // get simplified result + and_f_t.simplify(true)?; // get simplified result assert_eq!(and_f_t, ff); // parse lambda expression string diff --git a/examples/church_encoding.rs b/examples/church_encoding.rs index 5f08ffc..e58c661 100644 --- a/examples/church_encoding.rs +++ b/examples/church_encoding.rs @@ -5,15 +5,17 @@ use lamcalc::{lambda, Error}; fn main() -> Result<(), Error> { let zero = lambda!(s. (z. z)); let suc = lambda!(n. s. z. s (n s z)); - let plus = lambda!(n. m. n {suc} m).simplify()?.to_owned(); + let plus = lambda!(n. m. n {suc} m).simplify(true)?.to_owned(); let mut nats = vec![zero]; for i in 1..10 { - let sx = lambda!({suc} {nats[i - 1]}).simplify()?.to_owned(); + let sx = lambda!({suc} {nats[i - 1]}).simplify(true)?.to_owned(); nats.push(sx); } - let sum = lambda!({plus} {nats[4]} {nats[3]}).simplify()?.to_owned(); + let sum = lambda!({plus} {nats[4]} {nats[3]}) + .simplify(true)? + .to_owned(); assert_eq!(sum, nats[7]); Ok(()) diff --git a/examples/parser.rs b/examples/parser.rs index 388357c..732f926 100644 --- a/examples/parser.rs +++ b/examples/parser.rs @@ -21,10 +21,10 @@ fn main() -> Result<(), Error> { "##, )?; - let and_t_f = lambda!({map["and"]} {tt} {ff}).simplify()?.to_owned(); + let and_t_f = lambda!({map["and"]} {tt} {ff}).simplify(true)?.to_owned(); assert_eq!(and_t_f, ff); - let or_t_f = lambda!({map["or"]} {tt} {ff}).simplify()?.to_owned(); + let or_t_f = lambda!({map["or"]} {tt} {ff}).simplify(true)?.to_owned(); assert_eq!(or_t_f, tt); Ok(()) diff --git a/examples/y_combinator.rs b/examples/y_combinator.rs index 9ee7132..9b57913 100644 --- a/examples/y_combinator.rs +++ b/examples/y_combinator.rs @@ -9,21 +9,21 @@ fn main() -> Result<(), Error> { let prev = lambda!(n. f. x. n (g. h. h (g f)) (u. x) (u. u)); let mut nats = vec![zero]; for i in 1..10 { - let sx = lambda!({suc} {nats[i - 1]}).simplify()?.to_owned(); + let sx = lambda!({suc} {nats[i - 1]}).simplify(true)?.to_owned(); nats.push(sx); assert_eq!( - lambda!({prev} {nats[i]}).simplify()?.to_string(), + lambda!({prev} {nats[i]}).simplify(true)?.to_string(), nats[i - 1].to_string() ); } // utilities let mul = lambda!(n. m. f. x. n (m f) x); - let if_n_is_zero = lambda!(n. n (w. x. y. y) (x. y. x)); + let if_n_is_zero = lambda!(n. (n (w. x. y. y)) (x. y. x)); assert_eq!( lambda!({if_n_is_zero} {nats[0]} {nats[2]} {nats[1]} ) - .simplify()? + .simplify(true)? .purify(), nats[2].purify() ); @@ -35,18 +35,21 @@ fn main() -> Result<(), Error> { let mut fact = lambda!(y. n. {if_n_is_zero} n (f. x. f x) ({mul} n (y ({prev} n)))); eprintln!("simplify fact"); - while fact.eval_normal_order(true) { + while fact.eval_normal_order(true, true) { eprintln!("fact = {}", fact); } let y_fact = lambda!({y} {fact}); - let res = lambda!({y_fact} {nats[3]}).purify().simplify()?.to_owned(); + let res = lambda!({y_fact} {nats[3]}) + .purify() + .simplify(true)? + .to_owned(); eprintln!("{}", res); assert_eq!(res, nats[6].purify()); // if you try to simplify Y combinator ... - eprintln!("simplify y: {}", y.simplify().unwrap_err()); // lamcalc::Error::SimplifyLimitExceeded + eprintln!("simplify y: {}", y.simplify(true).unwrap_err()); // lamcalc::Error::SimplifyLimitExceeded Ok(()) } diff --git a/src/eval.rs b/src/eval.rs index eed1199..aa72849 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -45,11 +45,11 @@ where { /// Simplify repeatedly using beta-reduction in normal order /// for at most [`SIMPLIFY_LIMIT`] times. - pub fn simplify(&mut self) -> Result<&mut Self, Error> { + pub fn simplify(&mut self, optimize: bool) -> Result<&mut Self, Error> { #[cfg(feature = "experimental")] GLOBAL_PROFILE.lock().unwrap().reset_counter(); for _ in 0..SIMPLIFY_LIMIT { - if !self.eval_normal_order(false) { + if !self.eval_normal_order(false, optimize) { return Ok(self); } } @@ -61,9 +61,22 @@ where /// the body of an abstraction before the arguments are reduced. /// /// return `false` if nothing changes, otherwise `true`. - pub fn eval_normal_order(&mut self, eta_reduce: bool) -> bool { + pub fn eval_normal_order(&mut self, eta_reduce: bool, optimize: bool) -> bool { #[cfg(feature = "experimental")] GLOBAL_PROFILE.lock().unwrap().inc_eval_fn_counter(); + + if optimize { + if self.try_add_opt_1() { + return true; + } + if self.try_add_opt_2() { + return true; + } + if self.try_mul_opt() { + return true; + } + } + if self.beta_reduce() { #[cfg(feature = "experimental")] GLOBAL_PROFILE.lock().unwrap().inc_beta_counter(); @@ -76,25 +89,230 @@ where } match self { Exp::Var(_) => false, - Exp::Abs(_, body) => body.eval_normal_order(eta_reduce), + Exp::Abs(_, body) => body.eval_normal_order(eta_reduce, optimize), Exp::App(l, body) => { - if l.eval_normal_order(eta_reduce) { + if l.eval_normal_order(eta_reduce, optimize) { true } else { - body.eval_normal_order(eta_reduce) + body.eval_normal_order(eta_reduce, optimize) } } } } } +// Church encoding optimization +mod optimize { + use crate::{lambda, Exp}; + + const PURE_MUL: std::cell::OnceCell> = std::cell::OnceCell::new(); + + impl Exp + where + T: Clone + Eq, + { + fn try_into_church_num(&self) -> Option<(u64, T, T)> { + let (f, body) = self.into_abs()?; + let (x, mut body) = body.into_abs()?; + let mut val = 0; + while let Some((func, app_body)) = body.into_app() { + let f = func.into_ident()?; + if f.1 != 2 { + return None; + } + val += 1; + body = app_body; + } + if body.into_ident()?.1 != 1 { + return None; + } + Some((val, f.0.clone(), x.0.clone())) + } + fn from_church_num(num: u64, f: T, x: T) -> Self { + let mut cur = Exp::Var(crate::Ident(x.clone(), 1)); + for _ in 0..num { + cur = Exp::App( + Box::new(Exp::Var(crate::Ident(f.clone(), 2))), + Box::new(cur), + ); + } + cur = Exp::Abs(crate::Ident(x, 0), Box::new(cur)); + cur = Exp::Abs(crate::Ident(f, 0), Box::new(cur)); + cur + } + fn is_add(&self) -> bool { + let add = lambda!(n. m. f. x. n f (m f x)).purify(); + self.purify() == add + } + /// Check if the function is `add k` + fn is_add_k(&self) -> Option { + let (_m, body) = self.into_abs()?; + let (_f, body) = body.into_abs()?; + let (_x, mut body) = body.into_abs()?; + let mut val = 0; + while let Some((func, app_body)) = body.into_app() { + if let Some(f) = func.into_ident() { + if f.1 != 2 { + return None; + } + val += 1; + body = app_body; + } else if let Some((m1, f1)) = func.into_app() { + let m1 = m1.into_ident()?; + let f1 = f1.into_ident()?; + if m1.1 != 3 || f1.1 != 2 { + return None; + } else { + return Some(val); + } + } else { + return None; + } + } + None + } + fn is_mul(&self) -> bool { + let binding = PURE_MUL; + let value = binding.get_or_init(|| lambda!(n. m. f. x. n (m f) x).purify()); + &self.purify() == value + } + } + impl Exp + where + T: Clone + Eq, + { + /// Try to apply add. optimization, return false if nothing changed + /// + /// The target form is `(add a) b` + pub fn try_add_opt_1(&mut self) -> bool { + let mut inner = || { + let (add, a) = self.into_app_mut()?; + if !add.is_add() { + return None; + } + + let (va, f, x) = a.try_into_church_num()?; + let m = add.into_abs()?.1.into_abs()?.0 .0.clone(); + + eprintln!("add opt 1: va = {va}"); + // construct add a + let mut cur = Exp::App( + Box::new(Exp::Var(crate::Ident(m.clone(), 3))), + Box::new(Exp::Var(crate::Ident(f.clone(), 2))), + ); + cur = Exp::App( + Box::new(cur), + Box::new(Exp::Var(crate::Ident(x.clone(), 1))), + ); + for _ in 0..va { + cur = Exp::App( + Box::new(Exp::Var(crate::Ident(f.clone(), 2))), + Box::new(cur), + ); + } + cur = Exp::Abs(crate::Ident(x, 0), Box::new(cur)); + cur = Exp::Abs(crate::Ident(f, 0), Box::new(cur)); + cur = Exp::Abs(crate::Ident(m, 0), Box::new(cur)); + + *self = cur; + Some(()) + }; + inner().is_some() + } + + /// Try to apply add. optimization, return false if nothing changed + /// + /// The target form is `(add a) b` + pub fn try_add_opt_2(&mut self) -> bool { + let mut inner = || { + let (add_a, b) = self.into_app_mut()?; + let vb = b.try_into_church_num()?; + // eprintln!("try add opt: match b = {}", vb.0); + // eprintln!("try add opt: add_a = {}", add_a.purify()); + let va = add_a.is_add_k()?; + eprintln!("add opt: va = {va}, vb = {}", vb.0); + let s = Self::from_church_num(va + vb.0, vb.1, vb.2); + *self = s; + Some(()) + }; + inner().is_some() + } + /// Try to apply add. optimization, return false if nothing changed + /// + /// The target form is `(add a) b` + pub fn try_mul_opt(&mut self) -> bool { + let mut inner = || { + let (mul_a, b) = self.into_app_mut()?; + let vb = b.try_into_church_num()?; + let (mul, a) = mul_a.into_app_mut()?; + if !mul.is_mul() { + return None; + } + let va = a.try_into_church_num()?; + let s = Self::from_church_num(va.0 * vb.0, va.1, va.2); + *self = s; + Some(()) + }; + inner().is_some() + } + } + + #[cfg(test)] + mod tests { + use crate::{lambda, Exp}; + + #[test] + fn test_church() { + assert!( + lambda!(f. x. f (f (f x))) + .purify() + .try_into_church_num() + .unwrap() + .0 + == 3 + ); + assert!(lambda!(f. x. f (f (x x))) + .purify() + .try_into_church_num() + .is_none()); + + assert!( + Exp::from_church_num(10, (), ()) + .try_into_church_num() + .unwrap() + .0 + == 10 + ); + + let add = lambda!(n. m. f. x. n f (m f x)); + for i in 0..5 { + let mut e = lambda!({add} {Exp::from_church_num(i, "f".into(), "x".into())}); + e.simplify(true).unwrap(); + eprintln!("{}", e); + } + } + #[test] + fn test_add_opt() { + let a = Exp::from_church_num(10, "f", "x").to_string_exp(); + let b = Exp::from_church_num(15, "f", "x").to_string_exp(); + let add = lambda!(n. m. f. x. n f (m f x)); + // let add_a = lambda!({add} {a}).simplify(false).unwrap().to_owned(); + let mut e = lambda!({add} {a} {b}); + eprintln!("{}", e); + while e.eval_normal_order(false, true) { + eprintln!("{}", e); + } + assert_eq!(e.try_into_church_num().unwrap().0, 25) + } + } +} + #[cfg(test)] mod tests { #[test] #[cfg(feature = "experimental")] fn bench_pred() -> Result<(), crate::Error> { use crate::eval::GLOBAL_PROFILE; - use crate::lambda; let suc = lambda!(n. f. x. f (n f x)); let prev = lambda!(n. f. x. n (g. h. h (g f)) (u. x) (u. u)); diff --git a/src/exp.rs b/src/exp.rs index 5de580e..3fe2d8a 100644 --- a/src/exp.rs +++ b/src/exp.rs @@ -241,6 +241,26 @@ impl Exp { } } +impl Exp +where + T: Clone + Eq + ToString, +{ + /// Transform arbitrary expression to string named expression + pub fn to_string_exp(&self) -> Exp { + match self { + Exp::Var(Ident(name, code)) => Exp::Var(Ident(name.to_string(), *code)), + Exp::Abs(Ident(name, code), body) => Exp::Abs( + Ident(name.to_string(), *code), + Box::new(body.to_string_exp()), + ), + Exp::App(func, body) => Exp::App( + Box::new(func.to_string_exp()), + Box::new(body.to_string_exp()), + ), + } + } +} + impl std::fmt::Display for Exp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -345,7 +365,7 @@ mod tests { fn test_subst_unbounded() -> Result<(), Error> { let mut exp = lambda!(x. y. f x y); exp.subst_unbounded(&String::from("f"), &lambda!(x. (y. z))); - exp.simplify()?; + exp.simplify(false)?; assert_eq!(exp, lambda!(x. (y. z))); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 0a8bbac..b5ed35d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,7 +108,7 @@ mod tests { let mut res = lambda!({and} {tt} {tt}); println!("res = {}", res); - while res.eval_normal_order(false) { + while res.eval_normal_order(false, false) { println!("res = {}", res); } assert_eq!(res.to_string(), "λx. λy. x"); @@ -118,18 +118,18 @@ mod tests { let zero = lambda!(s. (z. z)); let suc = lambda!(n. s. z. s (n s z)); let mut plus = lambda!(n. m. n {suc} m); - plus.simplify()?; + plus.simplify(true)?; let mut nats = vec![zero]; for i in 1..10 { let x = nats.last().unwrap(); let mut sx = lambda!({suc} {x}); - sx.simplify()?; + sx.simplify(true)?; eprintln!("{} = {}", i, sx.purify()); nats.push(sx); } let mut test = lambda!({plus} {nats[4]} {nats[3]}); - test.simplify()?; + test.simplify(true)?; println!("test = {:#}", test); assert_eq!(test.to_string(), nats[7].to_string()); diff --git a/src/parser.rs b/src/parser.rs index 1b9a571..29dc739 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -359,9 +359,9 @@ mod tests { "((λx. λy. ((((((x 即) 是) y) y) 即) 是) x) 色) 空" ); eprintln!("{:#}", exp); - exp.eval_normal_order(false); + exp.eval_normal_order(false, false); eprintln!("{:#}", exp); - exp.simplify()?; + exp.simplify(false)?; assert_eq!(exp.to_string(), "((((((色 即) 是) 空) 空) 即) 是) 色"); Ok(())