Skip to content

Commit

Permalink
refactor: Bitshift takes usize as arg
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Sep 30, 2024
1 parent 8b29b2f commit dbf3196
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 55 deletions.
9 changes: 8 additions & 1 deletion crates/evm/src/instructions/comparison_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if i > 31 {
return self.stack.push(0);
}
let i: usize = i.try_into().unwrap(); // Safe because i <= 31

// Right shift value by offset bits and then take the least significant byte.
let result = x.shr((31 - i) * 8) & 0xFF;
Expand All @@ -150,7 +151,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if shift > 255 {
return self.stack.push(0);
}

let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255
let result = val.wrapping_shl(shift);
self.stack.push(result)
}
Expand All @@ -163,6 +164,11 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
let shift = *popped[0];
let value = *popped[1];

// if shift is bigger than 255 return 0
if shift > 255 {
return self.stack.push(0);
}
let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 255
let result = value.wrapping_shr(shift);
self.stack.push(result)
}
Expand All @@ -187,6 +193,7 @@ pub impl ComparisonAndBitwiseOperations of ComparisonAndBitwiseOperationsTrait {
if (shift > 256) {
self.stack.push(sign)
} else {
let shift: usize = shift.try_into().unwrap(); // Safe because shift <= 256
// XORing with sign before and after the shift propagates the sign bit of the operation
let result = (sign ^ value.value).shr(shift) ^ sign;
self.stack.push(result)
Expand Down
2 changes: 1 addition & 1 deletion crates/evm/src/memory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl MemoryImpl of MemoryTrait {

// First erase byte value at offset, then set the new value using bitwise ops
let word: u128 = self.items.get(chunk_index.into());
let new_word = (word & ~mask) | (value.into().shl(right_offset.into() * 8));
let new_word = (word & ~mask) | (value.into().shl(right_offset * 8));
self.items.insert(chunk_index.into(), new_word);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/utils/src/crypto/blake2_compress.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn rotate_right(value: u64, n: u32) -> u64 {
let bits = BitSize::<u64>::bits(); // The number of bits in a u64
let n = n % bits; // Ensure n is less than 64

let res = value.wrapping_shr(n.into()) | value.wrapping_shl((bits - n).into());
let res = value.wrapping_shr(n) | value.wrapping_shl((bits - n));
res
}
}
Expand Down
24 changes: 12 additions & 12 deletions crates/utils/src/crypto/modexp/arith.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ pub fn mod_inv(x: Word) -> Word {
break;
}

let mask: u64 = 1_u64.shl(i.into()) - 1;
let mask: u64 = 1_u64.shl(i) - 1;
let xy = x.wrapping_mul(y) & mask;
let q = (mask + 1) / 2;
if xy >= q {
Expand All @@ -310,7 +310,7 @@ pub fn mod_inv(x: Word) -> Word {
};

let xy = x.wrapping_mul(y);
let q = 1_u64.wrapping_shl((WORD_BITS - 1).into());
let q = 1_u64.wrapping_shl((WORD_BITS - 1));
if xy >= q {
y += q;
}
Expand Down Expand Up @@ -415,7 +415,7 @@ pub fn borrowing_sub(x: Word, y: Word, borrow: bool) -> (Word, bool) {
/// The double word obtained by joining `hi` and `lo`
pub fn join_as_double(hi: Word, lo: Word) -> DoubleWord {
let hi: DoubleWord = hi.into();
(hi.shl(WORD_BITS.into())).into() + lo.into()
hi.shl(WORD_BITS).into() + lo.into()
}

/// Computes `x^2`, storing the result in `out`.
Expand Down Expand Up @@ -457,14 +457,14 @@ fn big_sq(ref x: MPNat, ref out: Felt252Vec<Word>) {
}

out.set(i + j, res.as_u64());
c = new_c + res.shr(WORD_BITS.into());
c = new_c + res.shr(WORD_BITS);

j += 1;
};

let (sum, carry) = carrying_add(out[i + s], c.as_u64(), false);
out.set(i + s, sum);
out.set(i + s + 1, (c.shr(WORD_BITS.into()) + (carry.into())).as_u64());
out.set(i + s + 1, (c.shr(WORD_BITS) + (carry.into())).as_u64());

i += 1;
}
Expand All @@ -482,8 +482,8 @@ pub fn in_place_shl(ref a: Felt252Vec<Word>, shift: u32) -> Word {
}

let mut a_digit = a[i];
let carry = a_digit.wrapping_shr(carry_shift.into());
a_digit = a_digit.wrapping_shl(shift.into()) | c;
let carry = a_digit.wrapping_shr(carry_shift);
a_digit = a_digit.wrapping_shl(shift) | c;
a.set(i, a_digit);

c = carry;
Expand All @@ -508,8 +508,8 @@ pub fn in_place_shr(ref a: Felt252Vec<Word>, shift: u32) -> Word {
let j = i - 1;

let mut a_digit = a[j];
let borrow = a_digit.wrapping_shl(borrow_shift.into());
a_digit = a_digit.wrapping_shr(shift.into()) | b;
let borrow = a_digit.wrapping_shl(borrow_shift);
a_digit = a_digit.wrapping_shr(shift) | b;
a.set(j, a_digit);

b = borrow;
Expand Down Expand Up @@ -574,7 +574,7 @@ pub fn in_place_mul_sub(ref a: Felt252Vec<Word>, ref x: Felt252Vec<Word>, y: Wor
+ offset_carry.into()
- ((x_digit.into()) * (y.into()));

let new_offset_carry = (offset_sum.shr(WORD_BITS.into())).as_u64();
let new_offset_carry = (offset_sum.shr(WORD_BITS)).as_u64();
let new_x = offset_sum.as_u64();
offset_carry = new_offset_carry;
a.set(i, new_x);
Expand Down Expand Up @@ -661,15 +661,15 @@ mod tests {
let mut result = mp_nat_to_u128(ref x);

let mask = BASE.wrapping_pow(x.digits.len().into()).wrapping_sub(1);
assert_eq!(result, n.wrapping_shl(shift.into()) & mask);
assert_eq!(result, n.wrapping_shl(shift) & mask);
}

fn check_in_place_shr(n: u128, shift: u32) {
let mut x = MPNatTrait::from_big_endian(n.to_be_bytes_padded());
in_place_shr(ref x.digits, shift);
let mut result = mp_nat_to_u128(ref x);

assert_eq!(result, n.wrapping_shr(shift.into()));
assert_eq!(result, n.wrapping_shr(shift));
}

fn check_mod_inv(n: Word) {
Expand Down
11 changes: 5 additions & 6 deletions crates/utils/src/crypto/modexp/mpnat.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ pub impl MPNatTraitImpl of MPNatTrait {

in_place_shr(ref b.digits, 1);

res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos.into())));
res.digits.set(wordpos, res.digits[wordpos] | (x.shl(bitpos)));

bitpos += 1;
if bitpos == WORD_BITS {
Expand Down Expand Up @@ -404,7 +404,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
let mut digits = Felt252VecImpl::new();
digits.expand(trailing_zeros + 1).unwrap();
let mut tmp = MPNat { digits };
tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits.into()));
tmp.digits.set(trailing_zeros, 1_u64.shl(additional_zero_bits));
tmp
};

Expand All @@ -415,7 +415,7 @@ pub impl MPNatTraitImpl of MPNatTrait {
digits.expand(num_digits).unwrap();
let mut tmp = MPNat { digits };
if additional_zero_bits > 0 {
tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits.into()));
tmp.digits.set(0, modulus.digits[trailing_zeros].shr(additional_zero_bits));
let mut i = 1;
loop {
if i == num_digits {
Expand All @@ -429,10 +429,9 @@ pub impl MPNatTraitImpl of MPNatTrait {
i - 1,
tmp.digits[i
- 1]
+ (d & power_of_two_mask)
.shl((WORD_BITS - additional_zero_bits).into())
+ (d & power_of_two_mask).shl(WORD_BITS - additional_zero_bits)
);
tmp.digits.set(i, d.shr(additional_zero_bits.into()));
tmp.digits.set(i, d.shr(additional_zero_bits));

i += 1;
};
Expand Down
53 changes: 39 additions & 14 deletions crates/utils/src/math.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::integer::{u512};
use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul};
use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul, Bounded};
use core::panic_with_felt252;
use core::traits::{BitAnd};

Expand Down Expand Up @@ -203,7 +203,7 @@ pub trait Bitshift<T> {
///
/// Panics if the shift is greater than 255.
/// Panics if the result overflows the type T.
fn shl(self: T, shift: T) -> T;
fn shl(self: T, shift: usize) -> T;

/// Shift a number right by a given number of bits.
///
Expand All @@ -219,7 +219,7 @@ pub trait Bitshift<T> {
/// # Panics
///
/// Panics if the shift is greater than 255.
fn shr(self: T, shift: T) -> T;
fn shr(self: T, shift: usize) -> T;
}

impl BitshiftImpl<
Expand All @@ -237,23 +237,35 @@ impl BitshiftImpl<
+BitSize<T>,
+TryInto<usize, T>,
> of Bitshift<T> {
fn shl(self: T, shift: T) -> T {
fn shl(self: T, shift: usize) -> T {
// if we shift by more than nb_bits of T, the result is 0
// we early return to save gas and prevent unexpected behavior
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// if the shift is within the bit size of u256 (<= 255 bits),
// use the POW_2 lookup table to get 2^shift for efficient multiplication
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
// In case the pow2 is greater than the max value of T, we have an overflow
// so we can panic
return self * (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
// for shifts greater than 255 bits, perform the shift manually
let two = One::one() + One::one();
self * two.pow(shift)
self * two.pow(shift.try_into().expect('mul Overflow'))
}

fn shr(self: T, shift: T) -> T {
fn shr(self: T, shift: usize) -> T {
// early return to save gas if shift > nb_bits of T
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// use the POW_2 lookup table when the bit size
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
return self / (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
let two = One::one() + One::one();
self / two.pow(shift)
self / two.pow(shift.try_into().expect('mul Overflow'))
}
}

Expand All @@ -270,7 +282,7 @@ pub trait WrappingBitshift<T> {
/// # Returns
///
/// The result of shifting `self` left by `shift` bits, wrapped if necessary
fn wrapping_shl(self: T, shift: T) -> T;
fn wrapping_shl(self: T, shift: usize) -> T;

/// Shift a number right by a given number of bits.
/// If the shift is greater than 255, the result is 0.
Expand All @@ -283,7 +295,7 @@ pub trait WrappingBitshift<T> {
/// # Returns
///
/// The result of shifting `self` right by `shift` bits, or 0 if shift > 255
fn wrapping_shr(self: T, shift: T) -> T;
fn wrapping_shr(self: T, shift: usize) -> T;
}

pub impl WrappingBitshiftImpl<
Expand All @@ -300,18 +312,31 @@ pub impl WrappingBitshiftImpl<
+OverflowingMul<T>,
+WrappingExponentiation<T>,
+BitSize<T>,
+Bounded<T>,
+Into<T, u256>,
+TryInto<usize, T>,
> of WrappingBitshift<T> {
fn wrapping_shl(self: T, shift: T) -> T {
fn wrapping_shl(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
let (result, _) = self.overflowing_mul(pow2_mod_t.try_into().unwrap());
return result;
}
let two = One::<T>::one() + One::<T>::one();
let (result, _) = self.overflowing_mul(two.wrapping_pow(shift));
result
}

fn wrapping_shr(self: T, shift: T) -> T {
fn wrapping_shr(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
return self / pow2_mod_t.try_into().unwrap();
}
let two = One::<T>::one() + One::<T>::one();

if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
if shift > BitSize::<T>::bits() - One::one() {
return Zero::zero();
}
self / two.pow(shift)
Expand Down
19 changes: 6 additions & 13 deletions crates/utils/src/traits/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub impl U8SpanExImpl of U8SpanExTrait {
Option::Some(byte) => {
let byte: u64 = (*byte.unbox()).into();
// Accumulate pending_word in a little endian manner
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand All @@ -69,7 +69,7 @@ pub impl U8SpanExImpl of U8SpanExTrait {
last_input_word += match self.get(full_u64_word_count * 8 + byte_counter.into()) {
Option::Some(byte) => {
let byte: u64 = (*byte.unbox()).into();
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand Down Expand Up @@ -246,17 +246,13 @@ pub impl ToBytesImpl<
fn to_be_bytes(self: T) -> Span<u8> {
let bytes_used = self.bytes_used();

let one = One::<T>::one();
let two = one + one;
let eight = two * two * two;

// 0xFF
let mask = Bounded::<u8>::MAX.into();

let mut bytes: Array<u8> = Default::default();
for i in 0
..bytes_used {
let val = Bitshift::<T>::shr(self, eight * (bytes_used - i - 1).into());
let val = Bitshift::<T>::shr(self, 8_u32 * (bytes_used.into() - i.into() - 1));
bytes.append((val & mask).try_into().unwrap());
};

Expand All @@ -270,9 +266,6 @@ pub impl ToBytesImpl<

fn to_le_bytes(mut self: T) -> Span<u8> {
let bytes_used = self.bytes_used();
let one = One::<T>::one();
let two = one + one;
let eight = two * two * two;

// 0xFF
let mask = Bounded::<u8>::MAX.into();
Expand All @@ -281,7 +274,7 @@ pub impl ToBytesImpl<

for i in 0
..bytes_used {
let val = self.shr(eight * i.into());
let val = self.shr(8_u32 * i.into());
bytes.append((val & mask).try_into().unwrap());
};

Expand Down Expand Up @@ -526,7 +519,7 @@ pub impl ByteArrayExt of ByteArrayExTrait {
Option::Some(byte) => {
let byte: u64 = byte.into();
// Accumulate pending_word in a little endian manner
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand All @@ -546,7 +539,7 @@ pub impl ByteArrayExt of ByteArrayExTrait {
last_input_word += match self.at(full_u64_word_count * 8 + byte_counter.into()) {
Option::Some(byte) => {
let byte: u64 = byte.into();
byte.shl(8_u64 * byte_counter.into())
byte.shl(8_u32 * byte_counter.into())
},
Option::None => { break; },
};
Expand Down
Loading

0 comments on commit dbf3196

Please sign in to comment.