Skip to content

Commit

Permalink
Avoid exponential behavior in fold_mul_chains
Browse files Browse the repository at this point in the history
  • Loading branch information
simonlindholm committed May 18, 2024
1 parent 1cb6b55 commit cdc8cdd
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions m2c/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,25 +811,32 @@ def fold(
expr: Expression, toplevel: bool, allow_sll: bool
) -> Tuple[Expression, int]:
if isinstance(expr, BinaryOp):
lbase, lnum = fold(expr.left, False, (expr.op != "<<"))
rbase, rnum = fold(expr.right, False, (expr.op != "<<"))
if expr.op == "<<" and isinstance(expr.right, Literal) and allow_sll:
# Left-shifts by small numbers are easier to understand if
# written as multiplications (they compile to the same thing).
if toplevel and lnum == 1 and not (1 <= expr.right.value <= 4):
return (expr, 1)
return (lbase, lnum << expr.right.value)
if (
expr.op == "*"
and isinstance(expr.right, Literal)
and (allow_sll or expr.right.value % 2 != 0)
):
return (lbase, lnum * expr.right.value)
if early_unwrap(lbase) == early_unwrap(rbase):
if expr.op == "+":
return (lbase, lnum + rnum)
if expr.op == "-":
return (lbase, lnum - rnum)
if expr.op in ("<<", "*") and isinstance(expr.right, Literal):
lbase, lnum = fold(expr.left, False, (expr.op != "<<"))
rhs = expr.right.value
if expr.op == "<<" and allow_sll:
# At top level, keep left shifts, unless they are by such
# small numbers that they are easier to understand as
# multiplications (they compile to the same thing).
if toplevel and lnum == 1 and not (1 <= rhs <= 4):
return (expr, 1)
return (lbase, lnum << rhs)
if expr.op == "*" and (allow_sll or rhs % 2 != 0):
# If we don't allow << to be expanded into multiplication
# because the outer layer is already <<'ing, don't allow
# multiplication by even numbers either, because the power
# of two part of that scalar would have been folded into
# the outer shift unless something weird was up, which we
# want to highlight.
return (lbase, lnum * rhs)
if expr.op in ("+", "-") and toplevel:
lbase, lnum = fold(expr.left, False, (expr.op != "<<"))
rbase, rnum = fold(expr.right, False, (expr.op != "<<"))
if early_unwrap(lbase) == early_unwrap(rbase):
if expr.op == "+":
return (lbase, lnum + rnum)
if expr.op == "-":
return (lbase, lnum - rnum)
if isinstance(expr, UnaryOp) and expr.op == "-" and not toplevel:
base, num = fold(expr.expr, False, True)
return (base, -num)
Expand Down Expand Up @@ -999,9 +1006,13 @@ def handle_add(args: InstrArgs) -> Expression:
# addiu instructions can sometimes be emitted as addu instead, when the
# offset is too large.
if isinstance(rhs, Literal):
return handle_addi_real(output_reg, args.reg_ref(1), lhs, rhs, args)
return fold_mul_chains(
handle_addi_real(output_reg, args.reg_ref(1), lhs, rhs, args)
)
if isinstance(lhs, Literal):
return handle_addi_real(output_reg, args.reg_ref(2), rhs, lhs, args)
return fold_mul_chains(
handle_addi_real(output_reg, args.reg_ref(2), rhs, lhs, args)
)

return handle_add_real(output_reg, lhs, rhs, args)

Expand Down

0 comments on commit cdc8cdd

Please sign in to comment.