diff --git a/src/treewalk/interpreter.rs b/src/treewalk/interpreter.rs index d30a31d..7849215 100644 --- a/src/treewalk/interpreter.rs +++ b/src/treewalk/interpreter.rs @@ -162,7 +162,7 @@ impl Interpreter { .ok_or(InterpreterError::ExpectedBoolean(self.state.call_stack()))? } - fn evaluate_binary_operation( + fn evaluate_binary_operation_outer( &self, left: &Expr, op: &BinOp, @@ -171,6 +171,15 @@ impl Interpreter { let left = self.evaluate_expr(left)?; let right = self.evaluate_expr(right)?; + self.evaluate_binary_operation(left, op, right) + } + + fn evaluate_binary_operation( + &self, + left: ExprResult, + op: &BinOp, + right: ExprResult, + ) -> Result { if matches!(op, BinOp::In) { if let Some(mut iterable) = right.try_into_iter() { return Ok(ExprResult::Boolean(iterable.contains(left))); @@ -227,35 +236,12 @@ impl Interpreter { && right.as_object().is_some() && matches!(op, BinOp::Equals | BinOp::NotEquals) { - match op { - BinOp::Equals => self.evaluate_method( - left, - Dunder::Eq.value(), - &ResolvedArguments::default().add_arg(right), - ), - BinOp::NotEquals => { - if left - .as_object() - .unwrap() - .get(self, Dunder::Ne.value()) - .is_some() - { - self.evaluate_method( - left, - Dunder::Ne.value(), - &ResolvedArguments::default().add_arg(right), - ) - } else { - let result = self.evaluate_method( - left, - Dunder::Eq.value(), - &ResolvedArguments::default().add_arg(right), - )?; - Ok(result.inverted()) - } - } + let dunder = match op { + BinOp::Equals => Dunder::Eq.value(), + BinOp::NotEquals => Dunder::Ne.value(), _ => unreachable!(), - } + }; + self.evaluate_method(left, dunder, &ResolvedArguments::default().add_arg(right)) } else { evaluators::evaluate_object_comparison(left, op, right) } @@ -740,7 +726,7 @@ impl Interpreter { value: &Expr, ) -> Result<(), InterpreterError> { let op = operator.to_bin_op(); - let result = self.evaluate_binary_operation(target, &op, value)?; + let result = self.evaluate_binary_operation_outer(target, &op, value)?; self.evaluate_assignment_inner(target, result) } @@ -1422,7 +1408,7 @@ impl Interpreter { } => self.evaluate_dict_comprehension(key, value, range, key_body, value_body), Expr::UnaryOperation { op, right } => self.evaluate_unary_operation(op, right), Expr::BinaryOperation { left, op, right } => { - self.evaluate_binary_operation(left, op, right) + self.evaluate_binary_operation_outer(left, op, right) } Expr::Await { right } => self.evaluate_await(right), Expr::FunctionCall { name, args, callee } => { @@ -8512,6 +8498,8 @@ f = Foo(4) g = Foo(4) a = f == g b = f != g +c = f.__ne__ +d = c(g) "#; let (mut parser, mut interpreter) = init(input); @@ -8524,6 +8512,14 @@ b = f != g interpreter.state.read("b"), Some(ExprResult::Boolean(false)) ); + let Some(ExprResult::Method(method)) = interpreter.state.read("c") else { + panic!("Expected a method!"); + }; + assert!(matches!(method.receiver(), Some(ExprResult::Object(_)))); + assert_eq!( + interpreter.state.read("d"), + Some(ExprResult::Boolean(false)) + ); } } } diff --git a/src/treewalk/types/object.rs b/src/treewalk/types/object.rs index 470b252..dc21fb2 100644 --- a/src/treewalk/types/object.rs +++ b/src/treewalk/types/object.rs @@ -27,6 +27,7 @@ impl Object { Box::new(InitBuiltin), Box::new(NewBuiltin), Box::new(EqBuiltin), + Box::new(NeBuiltin), ] } @@ -230,3 +231,34 @@ impl Callable for EqBuiltin { BindingType::Instance } } + +/// The default behavior in Python for the `!=` sign is to call the `Dunder::Eq` and invert the +/// result. This is only used when `Dunder::Ne` is not overridden by a user-defined class. +struct NeBuiltin; + +impl Callable for NeBuiltin { + fn call( + &self, + interpreter: &Interpreter, + args: ResolvedArguments, + ) -> Result { + let receiver = args.get_self().ok_or(InterpreterError::ExpectedObject( + interpreter.state.call_stack(), + ))?; + let result = interpreter.evaluate_method( + receiver, + Dunder::Eq.value(), + &ResolvedArguments::default().add_arg(args.get_arg(0)), + )?; + + Ok(result.inverted()) + } + + fn name(&self) -> String { + Dunder::Ne.into() + } + + fn binding_type(&self) -> BindingType { + BindingType::Instance + } +}