From 1440c69f2d7f92617b2157b33ca31a8eeefa7e96 Mon Sep 17 00:00:00 2001 From: Mathieu <60658558+enitrat@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:37:31 +0700 Subject: [PATCH] feat: store_byte memory method (#329) * feat: store_byte memory method * chore: fmt fixes * chore: fmt fixes * address pr comments * address pr review * rebase * fix: var name --- .../src/instructions/memory_operations.cairo | 3 +- crates/evm/src/memory.cairo | 32 ++- crates/evm/src/tests/test_memory.cairo | 79 ++++++ crates/utils/src/helpers.cairo | 263 ++++++++++++++++++ crates/utils/src/math.cairo | 72 ++++- scripts/compare_snapshot.py | 14 +- 6 files changed, 452 insertions(+), 11 deletions(-) diff --git a/crates/evm/src/instructions/memory_operations.cairo b/crates/evm/src/instructions/memory_operations.cairo index d9bd3b8f7..6a9323d7d 100644 --- a/crates/evm/src/instructions/memory_operations.cairo +++ b/crates/evm/src/instructions/memory_operations.cairo @@ -129,8 +129,7 @@ impl MemoryOperation of MemoryOperationTrait { let offset = self.stack.pop_usize()?; let value = self.stack.pop()?; let value: u8 = (value.low & 0xFF).try_into().unwrap(); - let values = array![value].span(); - self.memory.store_n(values, offset); + self.memory.store_byte(value, offset); Result::Ok(()) } diff --git a/crates/evm/src/memory.cairo b/crates/evm/src/memory.cairo index 6dba65a87..0bfe9d5f4 100644 --- a/crates/evm/src/memory.cairo +++ b/crates/evm/src/memory.cairo @@ -11,11 +11,10 @@ use utils::constants::{ use cmp::{max}; use utils::{ helpers, helpers::SpanExtensionTrait, helpers::ArrayExtensionTrait, math::Exponentiation, - math::WrappingExponentiation + math::WrappingExponentiation, math::Bitshift }; use debug::PrintTrait; - #[derive(Destruct, Default)] struct Memory { items: Felt252Dict, @@ -26,6 +25,7 @@ trait MemoryTrait { fn new() -> Memory; fn size(ref self: Memory) -> usize; fn store(ref self: Memory, element: u256, offset: usize); + fn store_byte(ref self: Memory, value: u8, offset: usize); fn store_n(ref self: Memory, elements: Span, offset: usize); fn store_padded_segment(ref self: Memory, offset: usize, length: usize, source: Span); fn ensure_length(ref self: Memory, length: usize); @@ -75,6 +75,34 @@ impl MemoryImpl of MemoryTrait { self.store_element(element, chunk_index, offset_in_chunk); } + + /// Stores a single byte into memory at a specified offset. + /// + /// # Arguments + /// + /// * `self` - A mutable reference to the `Memory` instance to store the byte in. + /// * `value` - The byte value to store in memory. + /// * `offset` - The offset within memory to store the byte at. + #[inline(always)] + fn store_byte(ref self: Memory, value: u8, offset: usize) { + let new_min_bytes_len = helpers::ceil_bytes_len_to_next_32_bytes_word(offset + 1); + self.bytes_len = cmp::max(new_min_bytes_len, self.bytes_len); + + // Get offset's memory word index and left-based offset of byte in word. + let (chunk_index, left_offset) = u32_safe_divmod(offset, u32_as_non_zero(16)); + + // As the memory words are in big-endian order, we need to convert our left-based offset + // to a right-based one.a + let right_offset = 15 - left_offset; + let mask: u128 = 0xFF * helpers::pow2(right_offset.into() * 8); + + // First erase byte value at offset, then set the new value using bitwise ops + let word: u128 = self.items.get(chunk_index.into()); + let new_word = (word & ~mask) | (value.into().shl(right_offset.into() * 8)); + self.items.insert(chunk_index.into(), new_word); + } + + /// Stores a span of N bytes into memory at a specified offset. /// /// This function checks the alignment of the offset to 16-byte chunks, and handles the special case where the bytes to be diff --git a/crates/evm/src/tests/test_memory.cairo b/crates/evm/src/tests/test_memory.cairo index 8aea20bf4..d543f2090 100644 --- a/crates/evm/src/tests/test_memory.cairo +++ b/crates/evm/src/tests/test_memory.cairo @@ -510,3 +510,82 @@ fn test_store_padded_segment_should_add_n_elements_padded_with_offset_between_tw 'Wrong memory value' ); } +#[test] +#[available_gas(20000000)] +fn test_store_byte_should_store_byte_at_offset() { + // Given + let mut memory = MemoryTrait::new(); + + // When + memory.store_byte(0x01, 15); + + // Then + assert(memory.items[0] == 0x01, 'Wrong value for word 0'); + assert(memory.items[1] == 0x00, 'Wrong value for word 1'); + assert(memory.bytes_len == 32, 'Wrong memory length'); +} +#[test] +#[available_gas(20000000)] +fn test_store_byte_should_store_byte_at_offset_2() { + // Given + let mut memory = MemoryTrait::new(); + + // When + memory.store_byte(0xff, 14); + + // Then + assert(memory.items[0] == 0xff00, 'Wrong value for word 0'); + assert(memory.items[1] == 0x00, 'Wrong value for word 1'); + assert(memory.bytes_len == 32, 'Wrong memory length'); +} + +#[test] +#[available_gas(20000000)] +fn test_store_byte_should_store_byte_at_offset_in_existing_word() { + // Given + let mut memory = MemoryTrait::new(); + memory.items.insert(0, 0xFFFF); // Set the first word in memory to 0xFFFF; + memory.items.insert(1, 0xFFFF); + + // When + memory.store_byte(0x01, 30); + + // Then + assert(memory.items[0] == 0xFFFF, 'Wrong value for word 0'); + assert(memory.items[1] == 0x01FF, 'Wrong value for word 1'); + assert(memory.bytes_len == 32, 'Wrong memory length'); +} + +#[test] +#[available_gas(20000000)] +fn test_store_byte_should_store_byte_at_offset_in_new_word() { + // Given + let mut memory = MemoryTrait::new(); + + // When + memory.store_byte(0x01, 32); + + // Then + assert(memory.items[0] == 0x0, 'Wrong value for word 0'); + assert(memory.items[1] == 0x0, 'Wrong value for word 1'); + assert(memory.items[2] == 0x01000000000000000000000000000000, 'Wrong value for word 2'); + assert(memory.bytes_len == 64, 'Wrong memory length'); +} + +#[test] +#[available_gas(20000000)] +fn test_store_byte_should_store_byte_at_offset_in_new_word_with_existing_value_in_previous_word() { + // Given + let mut memory = MemoryTrait::new(); + memory.items.insert(0, 0x0100); + memory.items.insert(1, 0xffffffffffffffffffffffffffffffff); + + // When + memory.store_byte(0xAB, 17); + + // Then + assert(memory.items[0] == 0x0100, 'Wrong value in word 0'); + assert(memory.items[1] == 0xffABffffffffffffffffffffffffffff, 'Wrong value in word 1'); + assert(memory.bytes_len == 32, 'Wrong memory length'); +} + diff --git a/crates/utils/src/helpers.cairo b/crates/utils/src/helpers.cairo index 73161f9fe..ac0b2d098 100644 --- a/crates/utils/src/helpers.cairo +++ b/crates/utils/src/helpers.cairo @@ -69,6 +69,269 @@ fn pow256_rev(i: usize) -> u256 { } } +// Computes 2**pow for 0 <= pow < 128. +fn pow2(pow: usize) -> u128 { + if pow == 0 { + return 0x1; + } else if pow == 1 { + return 0x2; + } else if pow == 2 { + return 0x4; + } else if pow == 3 { + return 0x8; + } else if pow == 4 { + return 0x10; + } else if pow == 5 { + return 0x20; + } else if pow == 6 { + return 0x40; + } else if pow == 7 { + return 0x80; + } else if pow == 8 { + return 0x100; + } else if pow == 9 { + return 0x200; + } else if pow == 10 { + return 0x400; + } else if pow == 11 { + return 0x800; + } else if pow == 12 { + return 0x1000; + } else if pow == 13 { + return 0x2000; + } else if pow == 14 { + return 0x4000; + } else if pow == 15 { + return 0x8000; + } else if pow == 16 { + return 0x10000; + } else if pow == 17 { + return 0x20000; + } else if pow == 18 { + return 0x40000; + } else if pow == 19 { + return 0x80000; + } else if pow == 20 { + return 0x100000; + } else if pow == 21 { + return 0x200000; + } else if pow == 22 { + return 0x400000; + } else if pow == 23 { + return 0x800000; + } else if pow == 24 { + return 0x1000000; + } else if pow == 25 { + return 0x2000000; + } else if pow == 26 { + return 0x4000000; + } else if pow == 27 { + return 0x8000000; + } else if pow == 28 { + return 0x10000000; + } else if pow == 29 { + return 0x20000000; + } else if pow == 30 { + return 0x40000000; + } else if pow == 31 { + return 0x80000000; + } else if pow == 32 { + return 0x100000000; + } else if pow == 33 { + return 0x200000000; + } else if pow == 34 { + return 0x400000000; + } else if pow == 35 { + return 0x800000000; + } else if pow == 36 { + return 0x1000000000; + } else if pow == 37 { + return 0x2000000000; + } else if pow == 38 { + return 0x4000000000; + } else if pow == 39 { + return 0x8000000000; + } else if pow == 40 { + return 0x10000000000; + } else if pow == 41 { + return 0x20000000000; + } else if pow == 42 { + return 0x40000000000; + } else if pow == 43 { + return 0x80000000000; + } else if pow == 44 { + return 0x100000000000; + } else if pow == 45 { + return 0x200000000000; + } else if pow == 46 { + return 0x400000000000; + } else if pow == 47 { + return 0x800000000000; + } else if pow == 48 { + return 0x1000000000000; + } else if pow == 49 { + return 0x2000000000000; + } else if pow == 50 { + return 0x4000000000000; + } else if pow == 51 { + return 0x8000000000000; + } else if pow == 52 { + return 0x10000000000000; + } else if pow == 53 { + return 0x20000000000000; + } else if pow == 54 { + return 0x40000000000000; + } else if pow == 55 { + return 0x80000000000000; + } else if pow == 56 { + return 0x100000000000000; + } else if pow == 57 { + return 0x200000000000000; + } else if pow == 58 { + return 0x400000000000000; + } else if pow == 59 { + return 0x800000000000000; + } else if pow == 60 { + return 0x1000000000000000; + } else if pow == 61 { + return 0x2000000000000000; + } else if pow == 62 { + return 0x4000000000000000; + } else if pow == 63 { + return 0x8000000000000000; + } else if pow == 64 { + return 0x10000000000000000; + } else if pow == 65 { + return 0x20000000000000000; + } else if pow == 66 { + return 0x40000000000000000; + } else if pow == 67 { + return 0x80000000000000000; + } else if pow == 68 { + return 0x100000000000000000; + } else if pow == 69 { + return 0x200000000000000000; + } else if pow == 70 { + return 0x400000000000000000; + } else if pow == 71 { + return 0x800000000000000000; + } else if pow == 72 { + return 0x1000000000000000000; + } else if pow == 73 { + return 0x2000000000000000000; + } else if pow == 74 { + return 0x4000000000000000000; + } else if pow == 75 { + return 0x8000000000000000000; + } else if pow == 76 { + return 0x10000000000000000000; + } else if pow == 77 { + return 0x20000000000000000000; + } else if pow == 78 { + return 0x40000000000000000000; + } else if pow == 79 { + return 0x80000000000000000000; + } else if pow == 80 { + return 0x100000000000000000000; + } else if pow == 81 { + return 0x200000000000000000000; + } else if pow == 82 { + return 0x400000000000000000000; + } else if pow == 83 { + return 0x800000000000000000000; + } else if pow == 84 { + return 0x1000000000000000000000; + } else if pow == 85 { + return 0x2000000000000000000000; + } else if pow == 86 { + return 0x4000000000000000000000; + } else if pow == 87 { + return 0x8000000000000000000000; + } else if pow == 88 { + return 0x10000000000000000000000; + } else if pow == 89 { + return 0x20000000000000000000000; + } else if pow == 90 { + return 0x40000000000000000000000; + } else if pow == 91 { + return 0x80000000000000000000000; + } else if pow == 92 { + return 0x100000000000000000000000; + } else if pow == 93 { + return 0x200000000000000000000000; + } else if pow == 94 { + return 0x400000000000000000000000; + } else if pow == 95 { + return 0x800000000000000000000000; + } else if pow == 96 { + return 0x1000000000000000000000000; + } else if pow == 97 { + return 0x2000000000000000000000000; + } else if pow == 98 { + return 0x4000000000000000000000000; + } else if pow == 99 { + return 0x8000000000000000000000000; + } else if pow == 100 { + return 0x10000000000000000000000000; + } else if pow == 101 { + return 0x20000000000000000000000000; + } else if pow == 102 { + return 0x40000000000000000000000000; + } else if pow == 103 { + return 0x80000000000000000000000000; + } else if pow == 104 { + return 0x100000000000000000000000000; + } else if pow == 105 { + return 0x200000000000000000000000000; + } else if pow == 106 { + return 0x400000000000000000000000000; + } else if pow == 107 { + return 0x800000000000000000000000000; + } else if pow == 108 { + return 0x1000000000000000000000000000; + } else if pow == 109 { + return 0x2000000000000000000000000000; + } else if pow == 110 { + return 0x4000000000000000000000000000; + } else if pow == 111 { + return 0x8000000000000000000000000000; + } else if pow == 112 { + return 0x10000000000000000000000000000; + } else if pow == 113 { + return 0x20000000000000000000000000000; + } else if pow == 114 { + return 0x40000000000000000000000000000; + } else if pow == 115 { + return 0x80000000000000000000000000000; + } else if pow == 116 { + return 0x100000000000000000000000000000; + } else if pow == 117 { + return 0x200000000000000000000000000000; + } else if pow == 118 { + return 0x400000000000000000000000000000; + } else if pow == 119 { + return 0x800000000000000000000000000000; + } else if pow == 120 { + return 0x1000000000000000000000000000000; + } else if pow == 121 { + return 0x2000000000000000000000000000000; + } else if pow == 122 { + return 0x4000000000000000000000000000000; + } else if pow == 123 { + return 0x8000000000000000000000000000000; + } else if pow == 124 { + return 0x10000000000000000000000000000000; + } else if pow == 125 { + return 0x20000000000000000000000000000000; + } else if pow == 126 { + return 0x40000000000000000000000000000000; + } else if pow == 127 { + return 0x80000000000000000000000000000000; + } else { + return panic_with_felt252('pow2: pow >= 128'); + } +} + /// Splits a u256 into `len` bytes, big-endian, and appends the result to `dst`. fn split_word(mut value: u256, mut len: usize, ref dst: Array) { diff --git a/crates/utils/src/math.cairo b/crates/utils/src/math.cairo index 808a37541..a148b8b50 100644 --- a/crates/utils/src/math.cairo +++ b/crates/utils/src/math.cairo @@ -1,4 +1,6 @@ -use integer::{u256, u256_overflow_mul, u256_overflowing_add, u512, BoundedInt}; +use integer::{ + u256, u256_overflow_mul, u256_overflowing_add, u512, BoundedInt, u128_overflowing_mul +}; trait Exponentiation { /// Raise a number to a power. @@ -46,6 +48,37 @@ impl U256WrappingExponentiationImpl of WrappingExponentiation { } } +impl U128ExpImpl of Exponentiation { + fn pow(self: u128, mut exponent: u128) -> u128 { + if self == 0 { + return 0; + } + if exponent == 0 { + return 1; + } else { + return self * Exponentiation::pow(self, exponent - 1); + } + } +} + +impl U128WrappingExponentiationImpl of WrappingExponentiation { + fn wrapping_pow(self: u128, mut exponent: u128) -> u128 { + if self == 0 { + return 0; + } + let mut result = 1; + loop { + if exponent == 0 { + break; + } + let (new_result, _) = u128_overflowing_mul(result, self); + result = new_result; + exponent -= 1; + }; + result + } +} + impl Felt252WrappingExpImpl of WrappingExponentiation { fn wrapping_pow(self: felt252, mut exponent: felt252) -> felt252 { @@ -135,7 +168,7 @@ impl Felt252WrappingBitshiftImpl of WrappingBitshift { let val: u256 = self.into(); let shift: u256 = shift.into(); - // early return to save gas if shift > 255 + // early return to save gas if shift > 255 if shift > 255 { return 0; } @@ -163,3 +196,38 @@ impl U256WrappingBitshiftImpl of WrappingBitshift { self / 2.pow(shift) } } + +impl U128BitshiftImpl of Bitshift { + fn shl(self: u128, shift: u128) -> u128 { + if shift > 127 { + // 2.pow(shift) for shift > 255 will panic with 'u128_mul Overflow' + panic_with_felt252('u128_mul Overflow'); + } + self * 2.pow(shift) + } + + fn shr(self: u128, shift: u128) -> u128 { + if shift > 127 { + return 0; + } + self / 2.pow(shift) + } +} + +impl U128WrappingBitshiftImpl of WrappingBitshift { + fn wrapping_shl(self: u128, shift: u128) -> u128 { + let (result, _) = u128_overflowing_mul(self, 2.wrapping_pow(shift)); + result + } + + fn wrapping_shr(self: u128, shift: u128) -> u128 { + // if we shift by more than 255 bits, the result is 0 (the type is 128 bits wide) + // we early return to save gas + // and prevent unexpected behavior, e.g. 2.pow(128) == 0 mod 2^128, given we can't divide by zero + if shift > 127 { + return 0; + } + self / 2.pow(shift) + } +} + diff --git a/scripts/compare_snapshot.py b/scripts/compare_snapshot.py index 0af7db230..b7672013c 100644 --- a/scripts/compare_snapshot.py +++ b/scripts/compare_snapshot.py @@ -106,10 +106,9 @@ def compare_snapshots(current, previous): """Compare current and previous snapshots and return differences.""" worsened = [] improvements = [] + common_keys = set(current.keys()) & set(previous.keys()) - for key in previous: - if key not in current: - continue + for key in common_keys: prev = previous[key] cur = current[key] percentage_change = (cur - prev) * 100 / prev @@ -153,8 +152,13 @@ def print_colored_output(improvements, worsened, gas_changes): def total_gas_used(current, previous): - """Return the total gas used in the current and previous snapshot.""" - return sum(current.values()), sum(previous.values()) + """Return the total gas used in the current and previous snapshot, not taking into account added tests.""" + common_keys = set(current.keys()) & set(previous.keys()) + + cur_gas = sum(current[key] for key in common_keys) + prev_gas = sum(previous[key] for key in common_keys) + + return cur_gas, prev_gas def main():