diff --git a/m2c/evaluate.py b/m2c/evaluate.py index 043ce9f6..c14c2b67 100644 --- a/m2c/evaluate.py +++ b/m2c/evaluate.py @@ -811,25 +811,32 @@ def fold( expr: Expression, toplevel: bool, allow_sll: bool ) -> Tuple[Expression, int]: if isinstance(expr, BinaryOp): - lbase, lnum = fold(expr.left, False, (expr.op != "<<")) - rbase, rnum = fold(expr.right, False, (expr.op != "<<")) - if expr.op == "<<" and isinstance(expr.right, Literal) and allow_sll: - # Left-shifts by small numbers are easier to understand if - # written as multiplications (they compile to the same thing). - if toplevel and lnum == 1 and not (1 <= expr.right.value <= 4): - return (expr, 1) - return (lbase, lnum << expr.right.value) - if ( - expr.op == "*" - and isinstance(expr.right, Literal) - and (allow_sll or expr.right.value % 2 != 0) - ): - return (lbase, lnum * expr.right.value) - if early_unwrap(lbase) == early_unwrap(rbase): - if expr.op == "+": - return (lbase, lnum + rnum) - if expr.op == "-": - return (lbase, lnum - rnum) + if expr.op in ("<<", "*") and isinstance(expr.right, Literal): + lbase, lnum = fold(expr.left, False, (expr.op != "<<")) + rhs = expr.right.value + if expr.op == "<<" and allow_sll: + # At top level, keep left shifts, unless they are by such + # small numbers that they are easier to understand as + # multiplications (they compile to the same thing). + if toplevel and lnum == 1 and not (1 <= rhs <= 4): + return (expr, 1) + return (lbase, lnum << rhs) + if expr.op == "*" and (allow_sll or rhs % 2 != 0): + # If we don't allow << to be expanded into multiplication + # because the outer layer is already <<'ing, don't allow + # multiplication by even numbers either, because the power + # of two part of that scalar would have been folded into + # the outer shift unless something weird was up, which we + # want to highlight. + return (lbase, lnum * rhs) + if expr.op in ("+", "-") and toplevel: + lbase, lnum = fold(expr.left, False, (expr.op != "<<")) + rbase, rnum = fold(expr.right, False, (expr.op != "<<")) + if early_unwrap(lbase) == early_unwrap(rbase): + if expr.op == "+": + return (lbase, lnum + rnum) + if expr.op == "-": + return (lbase, lnum - rnum) if isinstance(expr, UnaryOp) and expr.op == "-" and not toplevel: base, num = fold(expr.expr, False, True) return (base, -num) @@ -999,9 +1006,13 @@ def handle_add(args: InstrArgs) -> Expression: # addiu instructions can sometimes be emitted as addu instead, when the # offset is too large. if isinstance(rhs, Literal): - return handle_addi_real(output_reg, args.reg_ref(1), lhs, rhs, args) + return fold_mul_chains( + handle_addi_real(output_reg, args.reg_ref(1), lhs, rhs, args) + ) if isinstance(lhs, Literal): - return handle_addi_real(output_reg, args.reg_ref(2), rhs, lhs, args) + return fold_mul_chains( + handle_addi_real(output_reg, args.reg_ref(2), rhs, lhs, args) + ) return handle_add_real(output_reg, lhs, rhs, args)