From a551c225aad52d2bbfde8988834083163cb6568c Mon Sep 17 00:00:00 2001 From: Danny McGee Date: Tue, 14 May 2024 21:22:24 -0400 Subject: [PATCH] Logical binary expression eval --- packages/server/src/pre/interpreter.rs | 437 +++++++++++++++++++++++-- packages/server/src/pre/mod.rs | 47 +-- 2 files changed, 418 insertions(+), 66 deletions(-) diff --git a/packages/server/src/pre/interpreter.rs b/packages/server/src/pre/interpreter.rs index 38379631..6563e586 100644 --- a/packages/server/src/pre/interpreter.rs +++ b/packages/server/src/pre/interpreter.rs @@ -138,6 +138,24 @@ impl TryFrom<&Token> for PostOp { } } +impl From for BinOp { + fn from(value: BinMath) -> Self { + BinOp::Math(value) + } +} + +impl From for BinOp { + fn from(value: BinBitwise) -> Self { + BinOp::Bitwise(value) + } +} + +impl From for BinOp { + fn from(value: BinLogic) -> Self { + BinOp::Logical(value) + } +} + pub struct Interpreter { source: ArcStr, defs: Arc>, @@ -145,6 +163,22 @@ pub struct Interpreter { } impl Interpreter { + pub fn new(source: ArcStr) -> Self { + Self { + source, + defs: Default::default(), + result: None, + } + } + + pub fn with_defs(source: ArcStr, defs: Arc>) -> Self { + Self { + source, + defs, + result: None, + } + } + fn eval(&self, expr: &Expr) -> gramatika::Result { let mut fork = self.fork(); expr.walk(&mut fork); @@ -159,6 +193,155 @@ impl Interpreter { result: None, } } + + fn eval_binary_expr( + &self, + lhs: Value, + op: BinOp, + rhs: Value, + op_token: &Token, + ) -> gramatika::Result { + use Value::*; + + match (lhs, op, rhs) { + // lhs ( `&&` | `||` | `==` | `!=` | `<=` | `<` | `>=` | `>` ) rhs + (lhs, BinOp::Logical(op), rhs) => { + use BinLogic::*; + match (lhs, op, rhs) { + // bool ( `&&` | `||` ) bool + (Bool(lhs), And, Bool(rhs)) => Ok(Bool(lhs && rhs)), + (Bool(lhs), Or, Bool(rhs)) => Ok(Bool(lhs || rhs)), + // any ( `&&` | `||` ) any + (lhs, And, rhs) | (lhs, Or, rhs) => { + let lhs = coerce_bool(lhs); + let rhs = coerce_bool(rhs); + self.eval_binary_expr(Bool(lhs), op.into(), Bool(rhs), op_token) + } + // bool ( `==` | `!=` ) bool + (Bool(lhs), Eq, Bool(rhs)) => Ok(Bool(lhs == rhs)), + (Bool(lhs), NotEq, Bool(rhs)) => Ok(Bool(lhs != rhs)), + // bool ( `<=` | `<` | `>=` | `>` ) rhs + // lhs ( `<=` | `<` | `>=` | `>` ) bool + (Bool(_), LessEq | Less | GreaterEq | Greater, _) + | (_, LessEq | Less | GreaterEq | Greater, Bool(_)) => Err(SpannedError { + message: format!("Cannot use operator {} to compare a boolean", op_token), + source: self.source.clone(), + span: Some(op_token.span()), + }), + // Happy-path numerics + // == + (Int(lhs), Eq, Int(rhs)) => Ok(Bool(lhs == rhs)), + (Uint(lhs), Eq, Uint(rhs)) => Ok(Bool(lhs == rhs)), + (Float(lhs), Eq, Float(rhs)) => Ok(Bool(lhs == rhs)), + // != + (Int(lhs), NotEq, Int(rhs)) => Ok(Bool(lhs != rhs)), + (Uint(lhs), NotEq, Uint(rhs)) => Ok(Bool(lhs != rhs)), + (Float(lhs), NotEq, Float(rhs)) => Ok(Bool(lhs != rhs)), + // <= + (Int(lhs), LessEq, Int(rhs)) => Ok(Bool(lhs <= rhs)), + (Uint(lhs), LessEq, Uint(rhs)) => Ok(Bool(lhs <= rhs)), + (Float(lhs), LessEq, Float(rhs)) => Ok(Bool(lhs <= rhs)), + // < + (Int(lhs), Less, Int(rhs)) => Ok(Bool(lhs < rhs)), + (Uint(lhs), Less, Uint(rhs)) => Ok(Bool(lhs < rhs)), + (Float(lhs), Less, Float(rhs)) => Ok(Bool(lhs < rhs)), + // >= + (Int(lhs), GreaterEq, Int(rhs)) => Ok(Bool(lhs >= rhs)), + (Uint(lhs), GreaterEq, Uint(rhs)) => Ok(Bool(lhs >= rhs)), + (Float(lhs), GreaterEq, Float(rhs)) => Ok(Bool(lhs >= rhs)), + // > + (Int(lhs), Greater, Int(rhs)) => Ok(Bool(lhs > rhs)), + (Uint(lhs), Greater, Uint(rhs)) => Ok(Bool(lhs > rhs)), + (Float(lhs), Greater, Float(rhs)) => Ok(Bool(lhs > rhs)), + // Numerics requiring coercion + // Int == rhs + (Int(lhs), Eq, Uint(rhs)) => { + if lhs < 0 { + Ok(Bool(false)) + } else { + Ok(Bool((lhs as usize) == rhs)) + } + } + (Int(lhs), Eq, Float(rhs)) => { + if rhs.fract().abs() == 0.0 { + Ok(Bool(lhs == (rhs.trunc() as isize))) + } else { + Ok(Bool(false)) + } + } + // Uint == rhs + (Uint(lhs), Eq, Int(rhs)) => { + if rhs < 0 { + Ok(Bool(false)) + } else { + Ok(Bool(lhs == (rhs as usize))) + } + } + (Uint(lhs), Eq, Float(rhs)) => { + if rhs < 0.0 { + Ok(Bool(false)) + } else if rhs.fract() == 0.0 { + Ok(Bool(lhs == (rhs.trunc() as usize))) + } else { + Ok(Bool(false)) + } + } + // other == other + (lhs, Eq, rhs) => self.eval_binary_expr(rhs, Eq.into(), lhs, op_token), + // any != any + (lhs, NotEq, rhs) => { + let Bool(eq) = self.eval_binary_expr(lhs, Eq.into(), rhs, op_token)? else { + unreachable!(); + }; + Ok(Bool(!eq)) + } + // Uint cannot be < 0 + (Uint(_), Less, Int(rhs)) if rhs < 0 => Ok(Bool(false)), + (Int(lhs), Less, Uint(_)) if lhs < 0 => Ok(Bool(true)), + (Uint(_), Less, Float(rhs)) if rhs < 0.0 => Ok(Bool(false)), + (Float(lhs), Less, Uint(_)) if lhs < 0.0 => Ok(Bool(true)), + // Uint < Int + (Uint(lhs), Less, Int(rhs)) => Ok(Bool(lhs < (rhs as usize))), + // Int < Uint + (Int(lhs), Less, Uint(rhs)) => Ok(Bool((lhs as usize) < rhs)), + // Float < rhs + (Float(lhs), Less, rhs) => { + let rhs = Float(coerce_f32(rhs)); + self.eval_binary_expr(Float(lhs), Less.into(), rhs, op_token) + } + // lhs < Float + (lhs, Less, Float(rhs)) => { + let lhs = Float(coerce_f32(lhs)); + self.eval_binary_expr(lhs, Less.into(), Float(rhs), op_token) + } + // lhs > rhs -> rhs < lhs + (lhs, Greater, rhs) => self.eval_binary_expr(rhs, Less.into(), lhs, op_token), + // lhs <= rhs -> (lhs == rhs) || (lhs < rhs) + (lhs, LessEq, rhs) => { + if let Ok(Bool(true)) = self.eval_binary_expr(lhs, Eq.into(), rhs, op_token) + { + Ok(Bool(true)) + } else { + self.eval_binary_expr(lhs, Less.into(), rhs, op_token) + } + } + // lhs >= rhs -> rhs <= lhs + (lhs, GreaterEq, rhs) => { + self.eval_binary_expr(rhs, LessEq.into(), lhs, op_token) + } + } + } + _ => todo!(), + } + } +} + +impl From for Interpreter +where S: Into +{ + fn from(value: S) -> Self { + Self::new(value.into()) + } } impl Visitor for Interpreter { @@ -261,28 +444,240 @@ impl Visitor for Interpreter { } }; - use Value::*; - - let result = match (lhs, op, rhs) { - (lhs, BinOp::Logical(op), rhs) => { - use BinLogic::*; - match (lhs, op, rhs) { - // Booleans - (Bool(lhs), Eq, Bool(rhs)) => Ok(Bool(lhs == rhs)), - (Bool(lhs), NotEq, Bool(rhs)) => Ok(Bool(lhs != rhs)), - (Bool(lhs), Or, Bool(rhs)) => Ok(Bool(lhs || rhs)), - (Bool(lhs), And, Bool(rhs)) => Ok(Bool(lhs && rhs)), - (Bool(_), _, _) | (_, _, Bool(_)) => Err(SpannedError { - message: format!("Cannot use operator {} to compare booleans", expr.op), - source: self.source.clone(), - span: Some(expr.op.span()), - }), - _ => todo!(), - } - } - _ => todo!(), - }; + self.result = Some(self.eval_binary_expr(lhs, op, rhs, &expr.op)); FlowControl::Break } } + +fn coerce_bool(value: Value) -> bool { + match value { + Value::Bool(value) => value, + Value::Float(f) if f == 0.0 => false, + Value::Int(0) | Value::Uint(0) => false, + _ => true, + } +} + +/// Panics if `value` is a `Value::Bool` +fn coerce_f32(value: Value) -> f32 { + match value { + Value::Float(value) => value, + Value::Int(value) => value as f32, + Value::Uint(value) => value as f32, + Value::Bool(_) => panic!("Attempted to coerce a boolean to an f32"), + } +} + +#[cfg(test)] +mod tests { + use gramatika::arcstr; + + use crate::testing::token; + + use super::{BinLogic, Interpreter, Value}; + + #[allow(unused_macros)] + macro_rules! eval { + ($expr:literal) => {{ + use ::gramatika::ParseStreamer; + use ::parser::traversal::Walk; + + let source = ::gramatika::arcstr::literal!($expr); + let mut parser = ::parser::ParseStream::from(source.clone()); + let expr = parser.parse::<::parser::expr::Expr>().unwrap(); + + let mut interpreter = super::Interpreter::new(source); + expr.walk(&mut interpreter); + + interpreter.result.unwrap() + }}; + } + + macro_rules! assert_matches { + ($( $tt:tt )*) => { + ::std::assert!(::std::matches!( $($tt)* )) + } + } + + mod eval_binary_expr { + use super::*; + + #[test] + fn four_greater_two() { + use BinLogic::*; + use Value::*; + + let source = arcstr::literal!("4 > 2"); + let op_token = token!(Operator ">" (1:2..1:3)); + let interpreter = Interpreter::from(source); + + assert_matches!( + interpreter.eval_binary_expr(Uint(4), Greater.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(4), Greater.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(4.), Greater.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(4), Greater.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(4), Greater.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(4.), Greater.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(4), Greater.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(4.), Greater.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(4), Greater.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + } + + #[test] + fn two_eq_two() { + use BinLogic::*; + use Value::*; + + let source = arcstr::literal!("2 == 2"); + let op_token = token!(Operator "==" (1:2..1:4)); + let interpreter = Interpreter::from(source); + + assert_matches!( + interpreter.eval_binary_expr(Uint(2), Eq.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), Eq.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), Eq.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(2), Eq.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), Eq.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), Eq.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), Eq.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), Eq.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(2), Eq.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + } + + #[test] + fn two_neq_two() { + use BinLogic::*; + use Value::*; + + let source = arcstr::literal!("2 != 2"); + let op_token = token!(Operator "!=" (1:2..1:4)); + let interpreter = Interpreter::from(source); + + assert_matches!( + interpreter.eval_binary_expr(Uint(2), NotEq.into(), Uint(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), NotEq.into(), Int(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), NotEq.into(), Float(2.), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(2), NotEq.into(), Int(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), NotEq.into(), Uint(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), NotEq.into(), Int(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Int(2), NotEq.into(), Float(2.), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.), NotEq.into(), Uint(2), &op_token), + Ok(Bool(false)), + ); + assert_matches!( + interpreter.eval_binary_expr(Uint(2), NotEq.into(), Float(2.), &op_token), + Ok(Bool(false)), + ); + } + + #[test] + fn float_cmp_int() { + use BinLogic::*; + use Value::*; + + let source = arcstr::literal!("2.5 > 2"); + let op_token = token!(Operator ">" (1:4..1:5)); + let interpreter = Interpreter::from(source); + + assert_matches!( + interpreter.eval_binary_expr(Float(2.5), Greater.into(), Float(2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.5), Greater.into(), Int(2), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(2.5), Greater.into(), Uint(2), &op_token), + Ok(Bool(true)), + ); + + let source = arcstr::literal!("-2.5 < -2"); + let op_token = token!(Operator "<" (1:6..1:7)); + let interpreter = Interpreter::from(source); + + assert_matches!( + interpreter.eval_binary_expr(Float(-2.5), Less.into(), Float(-2.), &op_token), + Ok(Bool(true)), + ); + assert_matches!( + interpreter.eval_binary_expr(Float(-2.5), Less.into(), Int(-2), &op_token), + Ok(Bool(true)), + ); + } + } +} diff --git a/packages/server/src/pre/mod.rs b/packages/server/src/pre/mod.rs index 987994ac..0045219b 100644 --- a/packages/server/src/pre/mod.rs +++ b/packages/server/src/pre/mod.rs @@ -1,17 +1,12 @@ -use std::sync::Arc; - use bevy_utils::HashMap; use gramatika::{ArcStr, SpannedError}; use lsp_types::Url; use parser::{ - comment::Comment, pre::{self, traversal::Walk}, - scopes::Scope, - ParseResult, ParseStreamer, Span, SyntaxTree, Token, + Span, }; -use ropey::Rope; -use crate::{documents::WgslDocumentBundle, pre::pruner::Pruner}; +use crate::pre::pruner::Pruner; mod interpreter; mod pruner; @@ -46,42 +41,4 @@ pub fn prune(uri: &Url, source: ArcStr, defs: &HashMap) -> Prune inactive_spans: pruner.inactive_spans, errors, } - - // #[derive(Bundle)] - // pub struct WgslDocumentBundle { - // pub uri: DocumentUri, - // pub source: WgslSource, - // pub tokens: WgslTokens, - // pub comments: WgslComments, - // pub errors: ParseErrors, - // pub ast: WgslAst, - // pub rope: WgslRope, - // pub scopes: WgslScopes, - // pub token_refs: TokenReferences, - // pub inactive_ranges: InactiveRanges, - // } - - // let mut parser = parser::ParseStream::from(&pruned_source); - // let tree = parser.parse::()?; - // let ParseResult { - // tokens, - // comments, - // mut errors, - // .. - // } = parser.into_inner(); - - // errors.extend(pre_parse_errors); - // errors.extend(pruner.errors); - - // Ok(( - // // uri, - // pruner.source, - // pruned_source, - // tokens, - // comments, - // errors, - // pruner.inactive_branches, - // parser::scopes::build(&tree), - // tree, - // )) }