[perf] Store blocks in a potentially smarter way (#1335)
* Store blocks in a potentially smarter way

* Merge smarter

* Revert "Merge smarter"

This reverts commit 9481f0d.

* Minor change

* Fix
Golovanov399 authored Feb 4, 2025
1 parent 43c4ccd commit 57d11cb
Showing 1 changed file with 113 additions and 66 deletions.
crates/vm/src/system/memory/offline.rs (113 additions, 66 deletions)
@@ -25,6 +25,92 @@ struct BlockData {
    size: usize,
}

+struct BlockMap {
+    /// Block ids. 0 is a special value standing for the default block.
+    id: AddressMap<usize, PAGE_SIZE>,
+    /// The place where non-default blocks are stored.
+    storage: Vec<BlockData>,
+    initial_block_size: usize,
+}
+
+impl BlockMap {
+    pub fn from_mem_config(mem_config: &MemoryConfig, initial_block_size: usize) -> Self {
+        assert!(initial_block_size.is_power_of_two());
+        Self {
+            id: AddressMap::from_mem_config(mem_config),
+            storage: vec![],
+            initial_block_size,
+        }
+    }
+
+    fn initial_block_data(pointer: u32, initial_block_size: usize) -> BlockData {
+        let aligned_pointer = (pointer / initial_block_size as u32) * initial_block_size as u32;
+        BlockData {
+            pointer: aligned_pointer,
+            size: initial_block_size,
+            timestamp: INITIAL_TIMESTAMP,
+        }
+    }
+
+    pub fn get_without_adding(&self, address: &(u32, u32)) -> BlockData {
+        let idx = self.id.get(address).unwrap_or(&0);
+        if idx == &0 {
+            Self::initial_block_data(address.1, self.initial_block_size)
+        } else {
+            self.storage[idx - 1].clone()
+        }
+    }
+
+    pub fn get(&mut self, address: &(u32, u32)) -> &BlockData {
+        let (address_space, pointer) = *address;
+        let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0);
+        if idx == &0 {
+            // `initial_block_size` is a power of two, as asserted in `from_mem_config`.
+            let pointer = pointer & !(self.initial_block_size as u32 - 1);
+            self.set_range(
+                &(address_space, pointer),
+                self.initial_block_size,
+                Self::initial_block_data(pointer, self.initial_block_size),
+            );
+            self.storage.last().unwrap()
+        } else {
+            &self.storage[idx - 1]
+        }
+    }
+
+    pub fn get_mut(&mut self, address: &(u32, u32)) -> &mut BlockData {
+        let (address_space, pointer) = *address;
+        let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0);
+        if idx == &0 {
+            let pointer = pointer - pointer % self.initial_block_size as u32;
+            self.set_range(
+                &(address_space, pointer),
+                self.initial_block_size,
+                Self::initial_block_data(pointer, self.initial_block_size),
+            );
+            self.storage.last_mut().unwrap()
+        } else {
+            &mut self.storage[idx - 1]
+        }
+    }
+
+    pub fn set_range(&mut self, address: &(u32, u32), len: usize, block: BlockData) {
+        let (address_space, pointer) = address;
+        self.storage.push(block);
+        for i in 0..len {
+            self.id
+                .insert(&(*address_space, pointer + i as u32), self.storage.len());
+        }
+    }
+
+    pub fn items(&self) -> impl Iterator<Item = ((u32, u32), &BlockData)> + '_ {
+        self.id
+            .items()
+            .filter(|(_, idx)| *idx > 0)
+            .map(|(address, idx)| (address, &self.storage[idx - 1]))
+    }
+}
+
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRecord<T> {
    pub address_space: T,
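The idea behind `BlockMap` above: the per-address map now holds only a small integer id, where 0 means "untouched, behave as the default initial block" and any id i > 0 points at `storage[i - 1]`, so a block covering a range is stored once rather than cloned into every address it covers. Below is a minimal, self-contained sketch of that scheme, not part of the diff: it substitutes a plain `HashMap` for the paged `AddressMap`, and the names `SimpleBlockMap`, `Block`, and the `main` demo are illustrative assumptions.

use std::collections::HashMap;

#[derive(Clone, Debug)]
struct Block {
    pointer: u32,
    size: usize,
    timestamp: u32,
}

// Simplified stand-in for `BlockMap`: id 0 = default block, id i > 0 = storage[i - 1].
struct SimpleBlockMap {
    id: HashMap<(u32, u32), usize>, // (address_space, pointer) -> block id
    storage: Vec<Block>,
    initial_block_size: usize,
}

impl SimpleBlockMap {
    fn new(initial_block_size: usize) -> Self {
        assert!(initial_block_size.is_power_of_two());
        Self {
            id: HashMap::new(),
            storage: Vec::new(),
            initial_block_size,
        }
    }

    // Read-only lookup: an untouched address synthesizes an aligned default block.
    fn get_without_adding(&self, address: (u32, u32)) -> Block {
        match self.id.get(&address).copied().unwrap_or(0) {
            0 => Block {
                pointer: address.1 & !(self.initial_block_size as u32 - 1),
                size: self.initial_block_size,
                timestamp: 0,
            },
            idx => self.storage[idx - 1].clone(),
        }
    }

    // Point every address in [pointer, pointer + len) at one freshly stored block.
    fn set_range(&mut self, address: (u32, u32), len: usize, block: Block) {
        self.storage.push(block);
        let id = self.storage.len(); // ids are 1-based indices into `storage`
        for i in 0..len as u32 {
            self.id.insert((address.0, address.1 + i), id);
        }
    }
}

fn main() {
    let mut map = SimpleBlockMap::new(4);
    // Untouched address 10 falls back to the default block aligned down to 8.
    assert_eq!(map.get_without_adding((1, 10)).pointer, 8);
    // Store one block covering [8, 16); all eight addresses now share a single id.
    map.set_range((1, 8), 8, Block { pointer: 8, size: 8, timestamp: 42 });
    assert_eq!(map.get_without_adding((1, 10)).size, 8);
    assert_eq!(map.get_without_adding((1, 15)).timestamp, 42);
}

As the commit title hedges ("potentially smarter"), the apparent win is that untouched addresses never materialize a block, and a range shares one stored `BlockData` instead of a per-address clone.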
@@ -53,10 +139,9 @@ impl<T: Copy> MemoryRecord<T> {
}

pub struct OfflineMemory<F> {
-    block_data: AddressMap<BlockData, PAGE_SIZE>,
+    block_data: BlockMap,
    data: Vec<PagedVec<F, PAGE_SIZE>>,
    as_offset: u32,
-    initial_block_size: usize,
    timestamp: u32,
    timestamp_max_bits: usize,

@@ -77,13 +162,10 @@ impl<F: PrimeField32> OfflineMemory<F> {
        range_checker: SharedVariableRangeCheckerChip,
        config: MemoryConfig,
    ) -> Self {
-        assert!(initial_block_size.is_power_of_two());
-
        Self {
-            block_data: AddressMap::from_mem_config(&config),
+            block_data: BlockMap::from_mem_config(&config, initial_block_size),
            data: Self::memory_image_to_paged_vec(initial_memory, config),
            as_offset: config.as_offset,
-            initial_block_size,
            timestamp: INITIAL_TIMESTAMP + 1,
            timestamp_max_bits: config.clk_max_bits,
            memory_bus,
@@ -231,18 +313,15 @@ impl<F: PrimeField32> OfflineMemory<F> {
            .collect();

        for &(address_space, pointer) in to_access.iter() {
-            let block = self.block_data.get(&(address_space, pointer)).unwrap();
-            if block.size > 0 && (block.pointer != pointer || block.size != N) {
+            let block = self.block_data.get(&(address_space, pointer));
+            if block.pointer != pointer || block.size != N {
                self.access(address_space, pointer, N, adapter_records);
            }
        }

        let mut equipartition = TimestampedEquipartition::<F, N>::new();
        for (address_space, pointer) in to_access {
-            let block = self.block_data.get(&(address_space, pointer)).unwrap();
-            if block.size == 0 {
-                continue;
-            }
+            let block = self.block_data.get(&(address_space, pointer));

            debug_assert_eq!(block.pointer % N as u32, 0);
            debug_assert_eq!(block.size, N);
@@ -304,10 +383,8 @@ impl<F: PrimeField32> OfflineMemory<F> {
                    size: half_size,
                    timestamp,
                };
-                for i in 0..half_size_u32 {
-                    self.block_data
-                        .insert(&(address_space, mid_ptr + i), block.clone());
-                }
+                self.block_data
+                    .set_range(&(address_space, mid_ptr), half_size, block);
            }
            if query >= cur_ptr + half_size_u32 {
                // The left is finalized; add it to the partition.
@@ -316,10 +393,8 @@ impl<F: PrimeField32> OfflineMemory<F> {
                    size: half_size,
                    timestamp,
                };
-                for i in 0..half_size_u32 {
-                    self.block_data
-                        .insert(&(address_space, cur_ptr + i), block.clone());
-                }
+                self.block_data
+                    .set_range(&(address_space, cur_ptr), half_size, block);
            }
            if mid_ptr <= query {
                cur_ptr = mid_ptr;
Expand All @@ -342,21 +417,13 @@ impl<F: PrimeField32> OfflineMemory<F> {

let mut prev_timestamp = None;

for i in 0..size as u32 {
let block = self.block_data.get(&(address_space, pointer + i));
if block.is_none() || block.unwrap().size == 0 {
self.block_data.insert(
&(address_space, pointer + i),
Self::initial_block_data(pointer + i, self.initial_block_size),
);
}
let block = self
.block_data
.get_mut(&(address_space, pointer + i))
.unwrap();
let mut i = 0;
while i < size as u32 {
let block = self.block_data.get_mut(&(address_space, pointer + i));
debug_assert!(i == 0 || prev_timestamp == Some(block.timestamp));
prev_timestamp = Some(block.timestamp);
block.timestamp = self.timestamp;
i = block.pointer + block.size as u32;
}
prev_timestamp.unwrap()
}
@@ -403,31 +470,24 @@ impl<F: PrimeField32> OfflineMemory<F> {
    ) {
        let left_block = self.block_data.get(&(address_space, pointer));

-        let left_timestamp = left_block.map(|b| b.timestamp).unwrap_or(INITIAL_TIMESTAMP);
-        let mut size = left_block
-            .map(|b| b.size)
-            .unwrap_or(self.initial_block_size);
-        if size == 0 {
-            size = self.initial_block_size;
-        }
+        let left_timestamp = left_block.timestamp;
+        let size = left_block.size;

        let right_timestamp = self
            .block_data
            .get(&(address_space, pointer + size as u32))
-            .map(|b| b.timestamp)
-            .unwrap_or(INITIAL_TIMESTAMP);
+            .timestamp;

        let timestamp = max(left_timestamp, right_timestamp);
-        for i in 0..2 * size as u32 {
-            self.block_data.insert(
-                &(address_space, pointer + i),
-                BlockData {
-                    pointer,
-                    size: 2 * size,
-                    timestamp,
-                },
-            );
-        }
+        self.block_data.set_range(
+            &(address_space, pointer),
+            2 * size,
+            BlockData {
+                pointer,
+                size: 2 * size,
+                timestamp,
+            },
+        );
        records.add_record(AccessAdapterRecord {
            timestamp,
            address_space: F::from_canonical_u32(address_space),
@@ -441,21 +501,8 @@ impl<F: PrimeField32> OfflineMemory<F> {
    }

    fn block_containing(&mut self, address_space: u32, pointer: u32) -> BlockData {
-        if let Some(block_data) = self.block_data.get(&(address_space, pointer)) {
-            if block_data.size > 0 {
-                return block_data.clone();
-            }
-        }
-        Self::initial_block_data(pointer, self.initial_block_size)
-    }
-
-    fn initial_block_data(pointer: u32, initial_block_size: usize) -> BlockData {
-        let aligned_pointer = (pointer / initial_block_size as u32) * initial_block_size as u32;
-        BlockData {
-            pointer: aligned_pointer,
-            size: initial_block_size,
-            timestamp: INITIAL_TIMESTAMP,
-        }
+        self.block_data
+            .get_without_adding(&(address_space, pointer))
    }

    pub fn get(&self, address_space: u32, pointer: u32) -> F {
