diff --git a/corelib/src/num/traits/ops/pow.cairo b/corelib/src/num/traits/ops/pow.cairo index 612f6aaeb40..0b99b005a65 100644 --- a/corelib/src/num/traits/ops/pow.cairo +++ b/corelib/src/num/traits/ops/pow.cairo @@ -40,16 +40,42 @@ mod mul_based { type Output = T; const fn pow(self: T, exp: usize) -> T { - let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); - let tail_result = if tail_exp == 0 { + if exp == 0 { H::one() } else { - Self::pow(H::mul(self, self), tail_exp) - }; + pow_non_zero_exp(self, exp) + } + } + } + + /// Equivalent of `PowByMul::pow` but assumes `exp` is non zero. + const fn pow_non_zero_exp, +Copy, +Drop>( + base: T, exp: usize, + ) -> T { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + if head_exp == 0 { + pow_non_zero_exp(H::mul(base, base), tail_exp) + } else { + pow_given_sqrt_base(base, tail_exp, base) + } + } + + /// Returns `(sqrt_base * sqrt_base).pow(exp) * acc`. + /// + /// Receives the square root of the base to avoid overflow if the squaring is not required for + /// the calculation (mostly when `exp` is 0). + const fn pow_given_sqrt_base, +Copy, +Drop>( + sqrt_base: T, exp: usize, acc: T, + ) -> T { + if exp == 0 { + acc + } else { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let base = H::mul(sqrt_base, sqrt_base); if head_exp == 0 { - tail_result + H::mul(pow_non_zero_exp(H::mul(base, base), tail_exp), acc) } else { - H::mul(tail_result, self) + pow_given_sqrt_base(base, tail_exp, H::mul(base, acc)) } } }