From 78496effb041f24199a78b7b437d5ff5f21487a4 Mon Sep 17 00:00:00 2001 From: mmatera Date: Thu, 20 Jul 2023 22:43:46 -0300 Subject: [PATCH] refactor arithmetic power --- mathics/builtin/arithfns/basic.py | 81 ++++++++++++++------ mathics/eval/arithmetic.py | 105 +++++++++++++++++++++++++- test/builtin/arithmetic/test_basic.py | 8 +- test/format/test_format.py | 43 ++++++++--- 4 files changed, 197 insertions(+), 40 deletions(-) diff --git a/mathics/builtin/arithfns/basic.py b/mathics/builtin/arithfns/basic.py index 558369c9f..23082458a 100644 --- a/mathics/builtin/arithfns/basic.py +++ b/mathics/builtin/arithfns/basic.py @@ -7,6 +7,9 @@ """ + +import sympy + from mathics.builtin.arithmetic import create_infix from mathics.builtin.base import ( BinaryOperator, @@ -45,7 +48,6 @@ Symbol, SymbolDivide, SymbolHoldForm, - SymbolNull, SymbolPower, SymbolTimes, ) @@ -56,10 +58,17 @@ SymbolInfix, SymbolLeft, SymbolMinus, + SymbolOverflow, SymbolPattern, - SymbolSequence, ) -from mathics.eval.arithmetic import eval_Plus, eval_Times +from mathics.eval.arithmetic import ( + associate_powers, + eval_Exponential, + eval_Plus, + eval_Power_inexact, + eval_Power_number, + eval_Times, +) from mathics.eval.nevaluator import eval_N from mathics.eval.numerify import numerify @@ -535,15 +544,15 @@ class Power(BinaryOperator, MPMathFunction): # Remember to up sympy doc link when this is corrected sympy_name = "Pow" + def eval_exp(self, x, evaluation): + "Power[E, x]" + return eval_Exponential(x) + def eval_check(self, x, y, evaluation): "Power[x_, y_]" - - # Power uses MPMathFunction but does some error checking first - if isinstance(x, Number) and x.is_zero: - if isinstance(y, Number): - y_err = y - else: - y_err = eval_N(y, evaluation) + # if x is zero + if x.is_zero: + y_err = y if isinstance(y, Number) else eval_N(y, evaluation) if isinstance(y_err, Number): py_y = y_err.round_to_float(permit_complex=True).real if py_y > 0: @@ -557,17 +566,47 @@ def eval_check(self, x, y, evaluation): evaluation.message( "Power", "infy", Expression(SymbolPower, x, y_err) ) - return SymbolComplexInfinity - if isinstance(x, Complex) and x.real.is_zero: - yhalf = Expression(SymbolTimes, y, RationalOneHalf) - factor = self.eval(Expression(SymbolSequence, x.imag, y), evaluation) - return Expression( - SymbolTimes, factor, Expression(SymbolPower, IntegerM1, yhalf) - ) - - result = self.eval(Expression(SymbolSequence, x, y), evaluation) - if result is None or result != SymbolNull: - return result + return SymbolComplexInfinity + + # If x and y are inexact numbers, use the numerical function + + if x.is_inexact() and y.is_inexact(): + try: + return eval_Power_inexact(x, y) + except OverflowError: + evaluation.message("General", "ovfl") + return Expression(SymbolOverflow) + + # Tries to associate powers a^b^c-> a^(b*c) + assoc = associate_powers(x, y) + if not assoc.has_form("Power", 2): + return assoc + + assoc = numerify(assoc, evaluation) + x, y = assoc.elements + # If x and y are numbers + if isinstance(x, Number) and isinstance(y, Number): + try: + return eval_Power_number(x, y) + except OverflowError: + evaluation.message("General", "ovfl") + return Expression(SymbolOverflow) + + # if x or y are inexact, leave the expression + # as it is: + if x.is_inexact() or y.is_inexact(): + return assoc + + # Finally, try to convert to sympy + base_sp, exp_sp = x.to_sympy(), y.to_sympy() + if base_sp is None or exp_sp is None: + # If base or exp can not be converted to sympy, + # returns the result of applying the associative + # rule. + return assoc + + result = from_sympy(sympy.Pow(base_sp, exp_sp)) + return result.evaluate_elements(evaluation) class Sqrt(SympyFunction): diff --git a/mathics/eval/arithmetic.py b/mathics/eval/arithmetic.py index 035dff801..38606483a 100644 --- a/mathics/eval/arithmetic.py +++ b/mathics/eval/arithmetic.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- """ -arithmetic-related evaluation functions. +helper functions for arithmetic evaluation, which do not +depends on the evaluation context. Conversions to Sympy are +used just as a last resource. Many of these do do depend on the evaluation context. Conversions to Sympy are used just as a last resource. @@ -320,6 +322,28 @@ def eval_complex_sign(n: BaseElement) -> Optional[BaseElement]: return sign or eval_complex_sign(expr) +def eval_Sign_number(n: Number) -> Number: + """ + Evals the absolute value of a number. + """ + if n.is_zero: + return Integer0 + if isinstance(n, (Integer, Rational, Real)): + return Integer1 if n.value > 0 else IntegerM1 + if isinstance(n, Complex): + abs_sq = eval_add_numbers( + *(eval_multiply_numbers(x, x) for x in (n.real, n.imag)) + ) + criteria = eval_add_numbers(abs_sq, IntegerM1) + if test_zero_arithmetic_expr(criteria): + return n + if n.is_inexact(): + return eval_multiply_numbers(n, eval_Power_number(abs_sq, RealM0p5)) + if test_zero_arithmetic_expr(criteria, numeric=True): + return n + return eval_multiply_numbers(n, eval_Power_number(abs_sq, RationalMOneHalf)) + + def eval_mpmath_function( mpmath_function: Callable, *args: Number, prec: Optional[int] = None ) -> Optional[Number]: @@ -347,6 +371,31 @@ def eval_mpmath_function( return call_mpmath(mpmath_function, tuple(mpmath_args), prec) +def eval_Exponential(exp: BaseElement) -> BaseElement: + """ + Eval E^exp + """ + # If both base and exponent are exact quantities, + # use sympy. + + if not exp.is_inexact(): + exp_sp = exp.to_sympy() + if exp_sp is None: + return None + return from_sympy(sympy.Exp(exp_sp)) + + prec = exp.get_precision() + if prec is not None: + if exp.is_machine_precision(): + number = mpmath.exp(exp.to_mpmath()) + result = from_mpmath(number) + return result + else: + with mpmath.workprec(prec): + number = mpmath.exp(exp.to_mpmath()) + return from_mpmath(number, prec) + + def eval_Plus(*items: BaseElement) -> BaseElement: "evaluate Plus for general elements" numbers, items_tuple = segregate_numbers_from_sorted_list(*items) @@ -645,8 +694,58 @@ def eval_Times(*items: BaseElement) -> BaseElement: ) +def associate_powers(expr: BaseElement, power: BaseElement = Integer1) -> BaseElement: + """ + base^a^b^c^...^power -> base^(a*b*c*...power) + provided one of the following cases + * `a`, `b`, ... `power` are all integer numbers + * `a`, `b`,... are Rational/Real number with absolute value <=1, + and the other powers are not integer numbers. + * `a` is not a Rational/Real number, and b, c, ... power are all + integer numbers. + """ + powers = [] + base = expr + if power is not Integer1: + powers.append(power) + + while base.has_form("Power", 2): + previous_base, outer_power = base, power + base, power = base.elements + if len(powers) == 0: + if power is not Integer1: + powers.append(power) + continue + if power is IntegerM1: + powers.append(power) + continue + if isinstance(power, (Rational, Real)): + if abs(power.value) < 1: + powers.append(power) + continue + # power is not rational/real and outer_power is integer, + elif isinstance(outer_power, Integer): + if power is not Integer1: + powers.append(power) + if isinstance(power, Integer): + continue + else: + break + # in any other case, use the previous base and + # exit the loop + base = previous_base + break + + if len(powers) == 0: + return base + elif len(powers) == 1: + return Expression(SymbolPower, base, powers[0]) + result = Expression(SymbolPower, base, Expression(SymbolTimes, *powers)) + return result + + def eval_add_numbers( - *numbers: Number, + *numbers: List[Number], ) -> BaseElement: """ Add the elements in ``numbers``. @@ -693,7 +792,7 @@ def eval_inverse_number(n: Number) -> Number: return eval_Power_number(n, IntegerM1) -def eval_multiply_numbers(*numbers: Number) -> Number: +def eval_multiply_numbers(*numbers: Number) -> BaseElement: """ Multiply the elements in ``numbers``. """ diff --git a/test/builtin/arithmetic/test_basic.py b/test/builtin/arithmetic/test_basic.py index d99b0b9dc..097208fc8 100644 --- a/test/builtin/arithmetic/test_basic.py +++ b/test/builtin/arithmetic/test_basic.py @@ -197,7 +197,7 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg): ("I^(2/3)", "(-1) ^ (1 / 3)", None), # In WMA, the next test would return ``-(-I)^(2/3)`` # which is less compact and elegant... - # ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None), + ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None), ("(2+3I)^3", "-46 + 9 I", None), ("(1.+3. I)^.6", "1.46069 + 1.35921 I", None), ("3^(1+2 I)", "3 ^ (1 + 2 I)", None), @@ -208,15 +208,15 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg): # sympy, which produces the result ("(3/Pi)^(-I)", "(3 / Pi) ^ (-I)", None), # Association rules - # ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"), + ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"), ('(a^2)^"w"', '(a ^ 2) ^ "w"', None), ('(a^2)^"w"', '(a ^ 2) ^ "w"', None), ("(a^2)^(1/2)", "Sqrt[a ^ 2]", None), ("(a^(1/2))^2", "a", None), ("(a^(1/2))^2", "a", None), ("(a^(3/2))^3.", "(a ^ (3 / 2)) ^ 3.", None), - # ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"), - # ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"), + ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"), + ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"), ("(a^(1.3))^3.", "(a ^ 1.3) ^ 3.", None), # Exponentials involving expressions ("(a^(p-2 q))^3", "a ^ (3 p - 6 q)", None), diff --git a/test/format/test_format.py b/test/format/test_format.py index 161ebc5df..ee81add3c 100644 --- a/test/format/test_format.py +++ b/test/format/test_format.py @@ -456,34 +456,53 @@ "Sqrt[1/(1+1/(1+1/a))]": { "msg": "SqrtBox", "text": { - "System`StandardForm": "Sqrt[1 / (1+1 / (1+1 / a))]", - "System`TraditionalForm": "Sqrt[1 / (1+1 / (1+1 / a))]", - "System`InputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]", - "System`OutputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]", + "System`StandardForm": "1 / Sqrt[1+1 / (1+1 / a)]", + "System`TraditionalForm": "1 / Sqrt[1+1 / (1+1 / a)]", + "System`InputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]", + "System`OutputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]", }, "mathml": { "System`StandardForm": ( - " 1 1 + 1 1 + 1 a ", + ( + r"1 1 + 1 " + r"1 + 1 a " + r"" + ), "Fragile!", ), "System`TraditionalForm": ( - " 1 1 + 1 1 + 1 a ", + ( + r"1 1 + 1 " + r"1 + 1 a " + r"" + ), "Fragile!", ), "System`InputForm": ( - "Sqrt [ 1  /  ( 1  +  1  /  ( 1  +  1  /  a ) ) ]", + ( + r"1  /  Sqrt [ " + r"1  +  1  /  " + r"( 1  +  1 " + r" /  a ) ]" + ), "Fragile!", ), "System`OutputForm": ( - "Sqrt [ 1  /  ( 1  +  1  /  ( 1  +  1  /  a ) ) ]", + ( + r"1  /  Sqrt [" + r" 1  +  1 " + r" /  ( 1 " + r" +  1  /  " + r"a ) ]" + ), "Fragile!", ), }, "latex": { - "System`StandardForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}", - "System`TraditionalForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}", - "System`InputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]", - "System`OutputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]", + "System`StandardForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}", + "System`TraditionalForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}", + "System`InputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]", + "System`OutputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]", }, }, # Grids, arrays and matrices