From ba4d98cb2e2407c88927282d767b7cb034661ca0 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Fri, 7 Jul 2023 02:06:31 +0200 Subject: [PATCH 01/18] feat(stdlib): Allow arbitrary length sha256 hashes You can now hash arbitrary length input if you have done the padding manually. Execute prepare_message_schedule_and_consume once for every 512-bit block. ``` use.std::crypto::hashes::sha256 begin push.0x5be0cd19.0x1f83d9ab.0x9b05688c.0x510e527f push.0xa54ff53a.0x3c6ef372.0xbb67ae85.0x6a09e667 exec.sha256::prepare_message_schedule_and_consume exec.sha256::prepare_message_schedule_and_consume end ``` --- stdlib/asm/crypto/hashes/sha256.masm | 50 ++++++++-------------------- stdlib/docs/crypto/hashes/sha256.md | 1 + 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/stdlib/asm/crypto/hashes/sha256.masm b/stdlib/asm/crypto/hashes/sha256.masm index 421e718943..92ccc82357 100644 --- a/stdlib/asm/crypto/hashes/sha256.masm +++ b/stdlib/asm/crypto/hashes/sha256.masm @@ -232,11 +232,13 @@ end #! - state0 through state7 are the hash state (in terms of 8 SHA256 words) #! - msg0 through msg15 are the 64 -bytes input message (in terms of 16 SHA256 words) #! See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113 -#! & https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution when i = 0 ) -proc.prepare_message_schedule_and_consume.2 +#! 
& https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) +export.prepare_message_schedule_and_consume.4 loc_storew.0 + loc_storew.2 dropw loc_storew.1 + loc_storew.3 dropw dup.15 @@ -1054,43 +1056,17 @@ proc.prepare_message_schedule_and_consume.2 movdn.8 exec.consume_message_word # consume msg[63] - push.0x6a09e667 - u32wrapping_add - - swap - push.0xbb67ae85 - u32wrapping_add - swap - - movup.2 - push.0x3c6ef372 - u32wrapping_add - movdn.2 - - movup.3 - push.0xa54ff53a - u32wrapping_add - movdn.3 - - movup.4 - push.0x510e527f - u32wrapping_add - movdn.4 - - movup.5 - push.0x9b05688c - u32wrapping_add - movdn.5 + push.0.0.0.0 + loc_loadw.3 - movup.6 - push.0x1f83d9ab - u32wrapping_add - movdn.6 + push.0.0.0.0 + loc_loadw.2 - movup.7 - push.0x5be0cd19 - u32wrapping_add - movdn.7 + repeat.8 + movup.8 + u32wrapping_add + movdn.7 + end end #! Consumes precomputed message schedule of padding bytes into hash state, returns final hash state. diff --git a/stdlib/docs/crypto/hashes/sha256.md b/stdlib/docs/crypto/hashes/sha256.md index 9d485a62e1..0d2ca114c5 100644 --- a/stdlib/docs/crypto/hashes/sha256.md +++ b/stdlib/docs/crypto/hashes/sha256.md @@ -2,5 +2,6 @@ ## std::crypto::hashes::sha256 | Procedure | Description | | ----------- | ------------- | +| prepare_message_schedule_and_consume | Computes whole message schedule of 64 message words and consumes them into hash state.

Input: [state0, state1, state2, state3, state4, state5, state6, state7, msg0, msg1, msg2, msg3, msg4, msg5, msg6, msg7, msg8, msg9, msg10, msg11, msg12, msg13, msg14, msg15]

Output: [state0', state1', state2', state3', state4', state5', state6', state7']

Where:

- state0 through state7 are the hash state (in terms of 8 SHA256 words)

- msg0 through msg15 are the 64 -bytes input message (in terms of 16 SHA256 words)

See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113

& https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) | | hash_2to1 | Given 64 -bytes input, this routine computes 32 -bytes SHA256 digest

Input: [m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,16) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 64 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | | hash_1to1 | Given 32 -bytes input, this routine computes 32 -bytes SHA256 digest

Expected stack state:

Input: [m0, m1, m2, m3, m4, m5, m6, m7, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,8) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 32 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | From 98aa1cc5a8b7649028d4a8398523601989d7a069 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Mon, 31 Jul 2023 15:32:16 +0200 Subject: [PATCH 02/18] feat(stdlib): Add sha256::hash_memory(addr, len) --- stdlib/asm/crypto/hashes/sha256.masm | 62 ++++++++++++++++++++++++++++ stdlib/docs/crypto/hashes/sha256.md | 1 + 2 files changed, 63 insertions(+) diff --git a/stdlib/asm/crypto/hashes/sha256.masm b/stdlib/asm/crypto/hashes/sha256.masm index 92ccc82357..8bcabbedf7 100644 --- a/stdlib/asm/crypto/hashes/sha256.masm +++ b/stdlib/asm/crypto/hashes/sha256.masm @@ -1555,3 +1555,65 @@ export.hash_1to1 exec.prepare_message_schedule_and_consume end + +#! Given a memory address and a message length in bytes, compute its sha256 digest +#! +#! - There must be space for writing the padding after the message in memory +#! - The padding space after the message must be all zeros before this procedure is called +#! +#! Input: [addr, len, ...] +#! Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...] 
+export.hash_memory.12 + # loc.0 (input address) + loc_store.0 + + # loc.1 (input length) + loc_store.1 + + # loc.2 (padded length): len(x) + (55 - len(x)) % 64 + 9 + push.55 loc_load.1 u32wrapping_sub push.63 u32checked_and + loc_load.1 u32checked_add u32checked_add.9 loc_store.2 + + # loc.3 (last u32): addr + newlen/16 - 1 + loc_load.2 u32checked_div.16 loc_load.0 u32wrapping_add u32wrapping_sub.1 loc_store.3 + + # loc.4 (u32 aligned padding byte): 0x80000000 >> ((len % 4) * 8) + loc_load.1 u32checked_mod.4 u32checked_mul.8 push.0x80000000 swap u32checked_shr loc_store.4 + + # loc.5 (memory offset of first padding byte): (len / 4) % 4 + loc_load.1 u32checked_div.4 u32checked_mod.4 loc_store.5 + + # loc.6 (memory address of first padding byte): addr + (len / 16) + loc_load.0 loc_load.1 u32checked_div.16 u32checked_add loc_store.6 + + # loc.7 (number of remaining 512-bit blocks to consume): padded_length // 64 + loc_load.2 u32checked_div.64 loc_store.7 + + # Set the first byte after the message to 0x80 + padw loc_load.6 mem_loadw loc_store.8 loc_store.9 loc_store.10 loc_store.11 + locaddr.8 loc_load.5 u32wrapping_add dup mem_load loc_load.4 u32wrapping_add swap mem_store + loc_load.11 loc_load.10 loc_load.9 loc_load.8 loc_load.6 mem_storew dropw + + # Set message bit length at end of padding + padw loc_load.3 mem_loadw + movup.3 drop loc_load.1 u32checked_mul.8 movdn.3 + loc_load.3 mem_storew dropw + + # Sha256 init + push.0x5be0cd19.0x1f83d9ab.0x9b05688c.0x510e527f + push.0xa54ff53a.0x3c6ef372.0xbb67ae85.0x6a09e667 + + # Consume sha256 blocks + loc_load.7 u32checked_neq.0 + while.true + padw loc_load.0 u32checked_add.3 mem_loadw movdnw.2 + padw loc_load.0 u32checked_add.2 mem_loadw movdnw.2 + padw loc_load.0 u32checked_add.1 mem_loadw movdnw.2 + padw loc_load.0 u32checked_add.0 mem_loadw movdnw.2 + exec.prepare_message_schedule_and_consume + + loc_load.0 u32checked_add.4 loc_store.0 + loc_load.7 u32checked_sub.1 dup loc_store.7 + u32checked_neq.0 + end +end diff 
--git a/stdlib/docs/crypto/hashes/sha256.md b/stdlib/docs/crypto/hashes/sha256.md index 0d2ca114c5..22e986ea0d 100644 --- a/stdlib/docs/crypto/hashes/sha256.md +++ b/stdlib/docs/crypto/hashes/sha256.md @@ -5,3 +5,4 @@ | prepare_message_schedule_and_consume | Computes whole message schedule of 64 message words and consumes them into hash state.

Input: [state0, state1, state2, state3, state4, state5, state6, state7, msg0, msg1, msg2, msg3, msg4, msg5, msg6, msg7, msg8, msg9, msg10, msg11, msg12, msg13, msg14, msg15]

Output: [state0', state1', state2', state3', state4', state5', state6', state7']

Where:

- state0 through state7 are the hash state (in terms of 8 SHA256 words)

- msg0 through msg15 are the 64 -bytes input message (in terms of 16 SHA256 words)

See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113

& https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) | | hash_2to1 | Given 64 -bytes input, this routine computes 32 -bytes SHA256 digest

Input: [m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,16) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 64 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | | hash_1to1 | Given 32 -bytes input, this routine computes 32 -bytes SHA256 digest

Expected stack state:

Input: [m0, m1, m2, m3, m4, m5, m6, m7, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,8) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 32 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | +| hash_memory | Given a memory address and a message length in bytes, compute its sha256 digest

- There must be space for writing the padding after the message in memory

- The padding space after the message must be all zeros

Input: [addr, len, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...] | From 69a50f1c5b73e1a925a4163785f36ecef786f8b0 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Mon, 31 Jul 2023 16:57:29 +0200 Subject: [PATCH 03/18] feat(stdlib): Improve comments in sha256::hash_memory --- stdlib/asm/crypto/hashes/sha256.masm | 14 +++++++------- stdlib/docs/crypto/hashes/sha256.md | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/stdlib/asm/crypto/hashes/sha256.masm b/stdlib/asm/crypto/hashes/sha256.masm index 8bcabbedf7..1c8fd90cab 100644 --- a/stdlib/asm/crypto/hashes/sha256.masm +++ b/stdlib/asm/crypto/hashes/sha256.masm @@ -1570,23 +1570,23 @@ export.hash_memory.12 # loc.1 (input length) loc_store.1 - # loc.2 (padded length): len(x) + (55 - len(x)) % 64 + 9 + # loc.2 (padded length): input_length + (55 - input_length) % 64 + 9 push.55 loc_load.1 u32wrapping_sub push.63 u32checked_and loc_load.1 u32checked_add u32checked_add.9 loc_store.2 - # loc.3 (last u32): addr + newlen/16 - 1 + # loc.3 (last memory address in padding): input_address + padded_length / 16 - 1 loc_load.2 u32checked_div.16 loc_load.0 u32wrapping_add u32wrapping_sub.1 loc_store.3 - # loc.4 (u32 aligned padding byte): 0x80000000 >> ((len % 4) * 8) + # loc.4 (u32 aligned padding byte): 0x80000000 >> ((input_length % 4) * 8) loc_load.1 u32checked_mod.4 u32checked_mul.8 push.0x80000000 swap u32checked_shr loc_store.4 - # loc.5 (memory offset of first padding byte): (len / 4) % 4 + # loc.5 (memory offset of first padding byte): (input_length / 4) % 4 loc_load.1 u32checked_div.4 u32checked_mod.4 loc_store.5 - # loc.6 (memory address of first padding byte): addr + (len / 16) + # loc.6 (memory address of first padding byte): input_address + (input_length / 16) loc_load.0 loc_load.1 u32checked_div.16 u32checked_add loc_store.6 - # loc.7 (number of remaining 512-bit blocks to consume): padded_length // 64 + # loc.7 (number of remaining 512-bit blocks to consume): padded_length / 64 loc_load.2 
u32checked_div.64 loc_store.7 # Set the first byte after the message to 0x80 @@ -1594,7 +1594,7 @@ export.hash_memory.12 locaddr.8 loc_load.5 u32wrapping_add dup mem_load loc_load.4 u32wrapping_add swap mem_store loc_load.11 loc_load.10 loc_load.9 loc_load.8 loc_load.6 mem_storew dropw - # Set message bit length at end of padding + # Set message length in bits at end of padding padw loc_load.3 mem_loadw movup.3 drop loc_load.1 u32checked_mul.8 movdn.3 loc_load.3 mem_storew dropw diff --git a/stdlib/docs/crypto/hashes/sha256.md b/stdlib/docs/crypto/hashes/sha256.md index 22e986ea0d..e45dd6359b 100644 --- a/stdlib/docs/crypto/hashes/sha256.md +++ b/stdlib/docs/crypto/hashes/sha256.md @@ -5,4 +5,4 @@ | prepare_message_schedule_and_consume | Computes whole message schedule of 64 message words and consumes them into hash state.

Input: [state0, state1, state2, state3, state4, state5, state6, state7, msg0, msg1, msg2, msg3, msg4, msg5, msg6, msg7, msg8, msg9, msg10, msg11, msg12, msg13, msg14, msg15]

Output: [state0', state1', state2', state3', state4', state5', state6', state7']

Where:

- state0 through state7 are the hash state (in terms of 8 SHA256 words)

- msg0 through msg15 are the 64 -bytes input message (in terms of 16 SHA256 words)

See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113

& https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) | | hash_2to1 | Given 64 -bytes input, this routine computes 32 -bytes SHA256 digest

Input: [m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,16) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 64 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | | hash_1to1 | Given 32 -bytes input, this routine computes 32 -bytes SHA256 digest

Expected stack state:

Input: [m0, m1, m2, m3, m4, m5, m6, m7, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,8) = 32 -bit word

Note, each SHA256 word is 32 -bit wide, so that's how input is expected.

As you've 32 -bytes, consider packing 4 consecutive bytes into single word,

maintaining big endian byte order.

SHA256 digest is represented in terms of eight 32 -bit words ( big endian byte order ). | -| hash_memory | Given a memory address and a message length in bytes, compute its sha256 digest

- There must be space for writing the padding after the message in memory

- The padding space after the message must be all zeros

Input: [addr, len, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...] | +| hash_memory | Given a memory address and a message length in bytes, compute its sha256 digest

- There must be space for writing the padding after the message in memory

- The padding space after the message must be all zeros before this procedure is called

Input: [addr, len, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...] | From 32598b8d1bf97d17f3ec442b6b92e71a627d1199 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Mon, 22 May 2023 02:20:33 -0700 Subject: [PATCH 04/18] feat: add smtinsert advice injector --- core/src/operations/decorators/advice.rs | 4 + processor/src/decorators/mod.rs | 1 + processor/src/decorators/tests.rs | 120 +++++++++++++++++++---- 3 files changed, 108 insertions(+), 17 deletions(-) diff --git a/core/src/operations/decorators/advice.rs b/core/src/operations/decorators/advice.rs index d583c22f57..c939cdbc80 100644 --- a/core/src/operations/decorators/advice.rs +++ b/core/src/operations/decorators/advice.rs @@ -181,6 +181,9 @@ pub enum AdviceInjector { /// Where KEY is computed as hash(A || B, domain), where domain is provided via the immediate /// value. HdwordToMap { domain: Felt }, + + /// TODO: add docs + SmtInsert, } impl fmt::Display for AdviceInjector { @@ -202,6 +205,7 @@ impl fmt::Display for AdviceInjector { Self::Ext2Inv => write!(f, "ext2_inv"), Self::Ext2Intt => write!(f, "ext2_intt"), Self::SmtGet => write!(f, "smt_get"), + Self::SmtInsert => write!(f, "smt_insert"), Self::MemToMap => write!(f, "mem_to_map"), Self::HdwordToMap { domain } => write!(f, "hdword_to_map.{domain}"), } diff --git a/processor/src/decorators/mod.rs b/processor/src/decorators/mod.rs index 27b92166fa..25c7a28580 100644 --- a/processor/src/decorators/mod.rs +++ b/processor/src/decorators/mod.rs @@ -45,6 +45,7 @@ where AdviceInjector::Ext2Inv => self.push_ext2_inv_result(), AdviceInjector::Ext2Intt => self.push_ext2_intt_result(), AdviceInjector::SmtGet => self.push_smtget_inputs(), + AdviceInjector::SmtInsert => todo!(), AdviceInjector::MemToMap => self.insert_mem_values_into_adv_map(), AdviceInjector::HdwordToMap { domain } => self.insert_hdword_into_adv_map(*domain), } diff --git a/processor/src/decorators/tests.rs b/processor/src/decorators/tests.rs index 17aee7542c..4b9848f1b1 100644 --- 
a/processor/src/decorators/tests.rs +++ b/processor/src/decorators/tests.rs @@ -7,7 +7,7 @@ use test_utils::{crypto::get_smt_remaining_key, rand::seeded_word}; use vm_core::{ crypto::{ hash::{Rpo256, RpoDigest}, - merkle::{EmptySubtreeRoots, MerkleStore, MerkleTree, NodeIndex}, + merkle::{EmptySubtreeRoots, MerkleStore, MerkleTree, NodeIndex, TieredSmt}, }, utils::IntoBytes, AdviceInjector, Decorator, ONE, ZERO, @@ -60,6 +60,9 @@ fn push_merkle_node() { assert_eq!(expected_stack, process.stack.trace_state()); } +// SMTGET TESTS +// ================================================================================================ + #[test] fn push_smtget() { // setup the test @@ -144,8 +147,62 @@ fn push_smtget() { } } +// SMTINSERT TESTS +// ================================================================================================ + +#[test] +fn inject_smtinsert() { + let mut smt = TieredSmt::default(); + + // --- insert into empty tree --------------------------------------------- + + let raw_a = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = build_key(raw_a); + let val_a = [ONE, ZERO, ZERO, ZERO]; + + // insertion should happen at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE; + // since we are replacing a node which is an empty subtree, the is_empty flag should also be ONE + let expected_stack = [ONE, ONE, ONE]; + let process = prepare_smt_insert(key_a, val_a, &smt, expected_stack.len()); + assert_eq!(build_expected(&expected_stack), process.stack.trace_state()); + + // --- update same key with different value ------------------------------- + + let val_b = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_a.into(), val_b); + + // we are updating a node at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE; + // since we are updating an existing leaf, the is_empty flag should be set to ZERO + let expected_stack = [ZERO, ONE, ONE]; + let process = prepare_smt_insert(key_a, val_b, &smt, 
expected_stack.len()); + assert_eq!(build_expected(&expected_stack), process.stack.trace_state()); +} + +fn prepare_smt_insert( + key: Word, + value: Word, + smt: &TieredSmt, + adv_stack_depth: usize, +) -> Process { + let root: Word = smt.root().into(); + let store = MerkleStore::from(smt); + + let stack_inputs = build_stack_inputs(value, key, root); + let advice_inputs = AdviceInputs::default().with_merkle_store(store); + let mut process = build_process(stack_inputs, advice_inputs); + + process.execute_op(Operation::Noop).unwrap(); + process + .execute_decorator(&Decorator::Advice(AdviceInjector::SmtInsert)) + .unwrap(); + + move_adv_to_stack(&mut process, adv_stack_depth); + + process +} + // HELPER FUNCTIONS -// -------------------------------------------------------------------------------------------- +// ================================================================================================ fn init_leaf(value: u64) -> Word { [Felt::new(value), Felt::ZERO, Felt::ZERO, Felt::ZERO] @@ -169,17 +226,7 @@ fn assert_case_smtget( expected_stack: &[Felt], ) { // build the process - let stack_inputs = StackInputs::try_from_values([ - root[0].as_int(), - root[1].as_int(), - root[2].as_int(), - root[3].as_int(), - key[0].as_int(), - key[1].as_int(), - key[2].as_int(), - key[3].as_int(), - ]) - .unwrap(); + let stack_inputs = build_stack_inputs(key, root, Word::default()); let remaining = get_smt_remaining_key(key, depth); let mapped = remaining.into_iter().chain(value.into_iter()).collect(); let advice_inputs = AdviceInputs::default() @@ -192,13 +239,52 @@ fn assert_case_smtget( // call the injector and clear the stack process.execute_op(Operation::Noop).unwrap(); process.execute_decorator(&Decorator::Advice(AdviceInjector::SmtGet)).unwrap(); - for _ in 0..8 { + + // replace operand stack contents with the data on the advice stack + move_adv_to_stack(&mut process, expected_stack.len()); + + assert_eq!(build_expected(expected_stack), 
process.stack.trace_state()); +} + +fn build_process( + stack_inputs: StackInputs, + adv_inputs: AdviceInputs, +) -> Process { + let advice_provider = MemAdviceProvider::from(adv_inputs); + Process::new(Kernel::default(), stack_inputs, advice_provider) +} + +fn build_stack_inputs(w0: Word, w1: Word, w2: Word) -> StackInputs { + StackInputs::try_from_values([ + w2[0].as_int(), + w2[1].as_int(), + w2[2].as_int(), + w2[3].as_int(), + w1[0].as_int(), + w1[1].as_int(), + w1[2].as_int(), + w1[3].as_int(), + w0[0].as_int(), + w0[1].as_int(), + w0[2].as_int(), + w0[3].as_int(), + ]) + .unwrap() +} + +fn build_key(prefix: u64) -> Word { + [ONE, ONE, ONE, Felt::new(prefix)] +} + +/// Removes all items from the operand stack and pushes the specified number of values from +/// the advice stack onto it. +fn move_adv_to_stack(process: &mut Process, adv_stack_depth: usize) { + let stack_depth = process.stack.depth(); + for _ in 0..stack_depth { process.execute_op(Operation::Drop).unwrap(); } - // expect the stack output - for _ in 0..expected_stack.len() { + for _ in 0..adv_stack_depth { process.execute_op(Operation::AdvPop).unwrap(); } - assert_eq!(build_expected(expected_stack), process.stack.trace_state()); } From eb083f8b77415f5bd0c0c1ea86b84af9ac0ae8fd Mon Sep 17 00:00:00 2001 From: grjte Date: Thu, 27 Jul 2023 17:07:28 -0400 Subject: [PATCH 05/18] feat(air): migrate range checker bus to LogUp --- air/src/constraints/chiplets/memory/mod.rs | 29 --- air/src/constraints/chiplets/mod.rs | 1 - air/src/constraints/mod.rs | 103 +++++++++++ air/src/constraints/range.rs | 203 +++++++++------------ air/src/constraints/stack/op_flags/mod.rs | 2 +- air/src/trace/decoder/mod.rs | 5 + air/src/trace/mod.rs | 6 +- air/src/trace/range.rs | 14 +- 8 files changed, 199 insertions(+), 164 deletions(-) diff --git a/air/src/constraints/chiplets/memory/mod.rs b/air/src/constraints/chiplets/memory/mod.rs index 94643a3a9c..e0120aa758 100644 --- a/air/src/constraints/chiplets/memory/mod.rs +++ 
b/air/src/constraints/chiplets/memory/mod.rs @@ -355,32 +355,3 @@ impl EvaluationFrameExt for &EvaluationFrame { self.selector_next(1) } } - -// EXTERNAL ACCESSORS -// ================================================================================================ -/// Trait to allow other processors to easily access the memory column values they need for -/// constraint calculations. -pub trait MemoryFrameExt { - // --- Column accessors ----------------------------------------------------------------------- - - /// The value of the lower 16-bits of the delta value being tracked between two consecutive - /// context IDs, addresses, or clock cycles in the current row. - fn memory_d0(&self) -> E; - /// The value of the upper 16-bits of the delta value being tracked between two consecutive - /// context IDs, addresses, or clock cycles in the current row. - fn memory_d1(&self) -> E; -} - -impl MemoryFrameExt for &EvaluationFrame { - // --- Column accessors ----------------------------------------------------------------------- - - #[inline(always)] - fn memory_d0(&self) -> E { - self.current()[MEMORY_D0_COL_IDX] - } - - #[inline(always)] - fn memory_d1(&self) -> E { - self.current()[MEMORY_D1_COL_IDX] - } -} diff --git a/air/src/constraints/chiplets/mod.rs b/air/src/constraints/chiplets/mod.rs index 7c960148ad..9810f91cec 100644 --- a/air/src/constraints/chiplets/mod.rs +++ b/air/src/constraints/chiplets/mod.rs @@ -6,7 +6,6 @@ use crate::utils::{are_equal, binary_not, is_binary}; mod bitwise; mod hasher; mod memory; -pub use memory::MemoryFrameExt; // CONSTANTS // ================================================================================================ diff --git a/air/src/constraints/mod.rs b/air/src/constraints/mod.rs index 5d55b79a6c..0b4330ca0e 100644 --- a/air/src/constraints/mod.rs +++ b/air/src/constraints/mod.rs @@ -1,3 +1,106 @@ +use super::{EvaluationFrame, ExtensionOf, Felt, FieldElement}; +use crate::trace::{ + chiplets::{MEMORY_D0_COL_IDX, 
MEMORY_D1_COL_IDX}, + decoder::{DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET}, +}; +use crate::utils::binary_not; + pub mod chiplets; pub mod range; pub mod stack; + +// ACCESSORS +// ================================================================================================ +/// Trait to allow other processors to easily access the column values they need for constraint +/// calculations. +pub trait MainFrameExt +where + F: FieldElement, + E: FieldElement + ExtensionOf, +{ + /// Returns true when a u32 stack operation that requires range checks is being performed. + fn u32_rc_op(&self) -> F; + + // --- Range check lookup accessors ----------------------------------------------------------------------- + + /// The value required for the first memory lookup when the memory chiplet requests range + /// checks. The value returned is the denominator used for including the value into the LogUp + /// lookup: (alpha - d0). The value d0 which is being range-checked is the lower 16-bits of the + /// delta value being tracked between two consecutive context IDs, addresses, or clock cycles in + /// the current row. + fn lookup_mv0(&self, alpha: E) -> E; + /// The value required for the second memory lookup when the memory chiplet requests range + /// checks. The value returned is the denominator used for including the value into the LogUp + /// lookup: (alpha - d1). The value d1 which is being range-checked is the upper 16-bits of the + /// delta value being tracked between two consecutive context IDs, addresses, or clock cycles in + /// the current row. + fn lookup_mv1(&self, alpha: E) -> E; + /// The value required for the first stack lookup when the stack requests range checks. The + /// value returned is the denominator used for including the value into the LogUp lookup: + /// (alpha - h0). The value h0 which is being range checked by the stack operation is stored in + /// the helper columns of the decoder section of the trace. 
+ fn lookup_sv0(&self, alpha: E) -> E; + /// The value required for the second stack lookup when the stack requests range checks. The + /// value returned is the denominator used for including the value into the LogUp lookup: + /// (alpha - h1). The value h1 which is being range checked by the stack operation is stored in + /// the helper columns of the decoder section of the trace. + fn lookup_sv1(&self, alpha: E) -> E; + /// The value required for the third stack lookup when the stack requests range checks. The + /// value returned is the denominator used for including the value into the LogUp lookup: + /// (alpha - h2). The value h2 which is being range checked by the stack operation is stored in + /// the helper columns of the decoder section of the trace. + fn lookup_sv2(&self, alpha: E) -> E; + /// The value required for the fourth stack lookup when the stack requests range checks. The + /// value returned is the denominator used for including the value into the LogUp lookup: + /// (alpha - h3). The value h3 which is being range checked by the stack operation is stored in + /// the helper columns of the decoder section of the trace. + fn lookup_sv3(&self, alpha: E) -> E; +} + +impl MainFrameExt for EvaluationFrame +where + F: FieldElement, + E: FieldElement + ExtensionOf, +{ + /// Returns true when the stack operation is a u32 operation that requires range checks. + /// TODO: this is also defined in the op flags. It's redefined here to avoid computing all of + /// the op flags when this is the only one needed, but ideally this should only be defined once. 
+ #[inline(always)] + fn u32_rc_op(&self) -> F { + let not_4 = binary_not(self.current()[DECODER_OP_BITS_OFFSET + 4]); + let not_5 = binary_not(self.current()[DECODER_OP_BITS_OFFSET + 5]); + self.current()[DECODER_OP_BITS_OFFSET + 6].mul(not_5).mul(not_4) + } + + // --- Intermediate values for LogUp lookups -------------------------------------------------- + + #[inline(always)] + fn lookup_mv0(&self, alpha: E) -> E { + alpha - self.current()[MEMORY_D0_COL_IDX].into() + } + + #[inline(always)] + fn lookup_mv1(&self, alpha: E) -> E { + alpha - self.current()[MEMORY_D1_COL_IDX].into() + } + + #[inline(always)] + fn lookup_sv0(&self, alpha: E) -> E { + alpha - self.current()[DECODER_USER_OP_HELPERS_OFFSET].into() + } + + #[inline(always)] + fn lookup_sv1(&self, alpha: E) -> E { + alpha - self.current()[DECODER_USER_OP_HELPERS_OFFSET + 1].into() + } + + #[inline(always)] + fn lookup_sv2(&self, alpha: E) -> E { + alpha - self.current()[DECODER_USER_OP_HELPERS_OFFSET + 2].into() + } + + #[inline(always)] + fn lookup_sv3(&self, alpha: E) -> E { + alpha - self.current()[DECODER_USER_OP_HELPERS_OFFSET + 3].into() + } +} diff --git a/air/src/constraints/range.rs b/air/src/constraints/range.rs index 6d28b07317..c0ba0c6bf1 100644 --- a/air/src/constraints/range.rs +++ b/air/src/constraints/range.rs @@ -1,7 +1,8 @@ use crate::{ - chiplets::{ChipletsFrameExt, MemoryFrameExt}, - trace::range::{B_RANGE_COL_IDX, Q_COL_IDX, S0_COL_IDX, S1_COL_IDX, V_COL_IDX}, - utils::{are_equal, binary_not, is_binary}, + chiplets::ChipletsFrameExt, + constraints::MainFrameExt, + trace::range::{B_RANGE_COL_IDX, M_COL_IDX, V_COL_IDX}, + utils::are_equal, Assertion, EvaluationFrame, Felt, FieldElement, TransitionConstraintDegree, }; use vm_core::{utils::collections::Vec, ExtensionOf}; @@ -15,11 +16,10 @@ use winter_air::AuxTraceRandElements; /// The number of boundary constraints required by the Range Checker pub const NUM_ASSERTIONS: usize = 2; /// The number of transition constraints required by the 
Range Checker. -pub const NUM_CONSTRAINTS: usize = 3; +pub const NUM_CONSTRAINTS: usize = 1; /// The degrees of the range checker's constraints, in the order they'll be added to the the result /// array when a transition is evaluated. pub const CONSTRAINT_DEGREES: [usize; NUM_CONSTRAINTS] = [ - 2, 2, // Selector flags must be binary: s0, s1. 9, // Enforce values of column v transition. ]; @@ -30,7 +30,7 @@ pub const NUM_AUX_ASSERTIONS: usize = 2; /// The number of transition constraints required by multiset checks for the Range Checker. pub const NUM_AUX_CONSTRAINTS: usize = 1; /// The degrees of the Range Checker's auxiliary column constraints, used for multiset checks. -pub const AUX_CONSTRAINT_DEGREES: [usize; NUM_AUX_CONSTRAINTS] = [7]; +pub const AUX_CONSTRAINT_DEGREES: [usize; NUM_AUX_CONSTRAINTS] = [9]; // BOUNDARY CONSTRAINTS // ================================================================================================ @@ -81,11 +81,16 @@ pub fn get_transition_constraint_count() -> usize { /// Enforces constraints for the range checker. pub fn enforce_constraints(frame: &EvaluationFrame, result: &mut [E]) { - // Constrain the selector flags. - let index = enforce_flags(frame, result); - - // Constrain the transition between rows of the range checker table. - enforce_delta(frame, &mut result[index..]); + // Constrain the transition of the value column between rows in the range checker table. 
+ result[0] = frame.change(V_COL_IDX) + * (frame.change(V_COL_IDX) - E::ONE) + * (frame.change(V_COL_IDX) - E::from(3_u8)) + * (frame.change(V_COL_IDX) - E::from(9_u8)) + * (frame.change(V_COL_IDX) - E::from(27_u8)) + * (frame.change(V_COL_IDX) - E::from(81_u8)) + * (frame.change(V_COL_IDX) - E::from(243_u8)) + * (frame.change(V_COL_IDX) - E::from(729_u16)) + * (frame.change(V_COL_IDX) - E::from(2187_u16)); } // --- AUXILIARY COLUMNS (FOR MULTISET CHECKS) ---------------------------------------------------- @@ -113,48 +118,40 @@ pub fn enforce_aux_constraints( let alpha = aux_rand_elements.get_segment_elements(0)[0]; // Enforce b_range. - enforce_running_product_b_range(main_frame, aux_frame, alpha, &mut result[..]); + enforce_b_range(main_frame, aux_frame, alpha, result); } // TRANSITION CONSTRAINT HELPERS // ================================================================================================ -// --- MAIN TRACE --------------------------------------------------------------------------------- - -/// Constrain the selector flags to binary values. -fn enforce_flags(frame: &EvaluationFrame, result: &mut [E]) -> usize { - let constraint_count = 2; - - result[0] = is_binary(frame.s0()); - result[1] = is_binary(frame.s1()); - - constraint_count -} - -/// Constrain the transition between rows in the range checker table. 
-fn enforce_delta(frame: &EvaluationFrame, result: &mut [E]) -> usize { - let constraint_count = 1; - - result[0] = frame.change(V_COL_IDX) - * (frame.change(V_COL_IDX) - E::ONE) - * (frame.change(V_COL_IDX) - E::from(3_u8)) - * (frame.change(V_COL_IDX) - E::from(9_u8)) - * (frame.change(V_COL_IDX) - E::from(27_u8)) - * (frame.change(V_COL_IDX) - E::from(81_u8)) - * (frame.change(V_COL_IDX) - E::from(243_u8)) - * (frame.change(V_COL_IDX) - E::from(729_u16)) - * (frame.change(V_COL_IDX) - E::from(2187_u16)); - - constraint_count -} - // --- AUXILIARY COLUMNS (FOR MULTISET CHECKS) ---------------------------------------------------- -/// Ensures that the running product is computed correctly in the column `b_range`. It enforces -/// that the value only changes after the padded rows, where the value of `z` is included at each -/// step, ensuring that the values in the range checker table are multiplied into `b_range` 0, 1, -/// 2, or 4 times, according to the selector flags. -fn enforce_running_product_b_range( +/// Ensures that the range checker bus is computed correctly. It enforces an implementation of the +/// LogUp lookup as a running sum "bus" column. All values in the range checker trace are saved +/// with their lookup multiplicity and the logarithmic derivatives are added to b_range. Values +/// for which lookups are requested from the stack and memory are each looked up with multiplicity +/// one, and the logarithmic derivatives are subtracted from b_range. +/// +/// Define the following variables: +/// - rc_value: the range checker value +/// - rc_multiplicity: the range checker multiplicity value +/// - flag_s: boolean flag indicating a stack operation with range checks. This flag is degree 3. +/// - sv0-sv3: stack value 0-3, the 4 values range-checked from the stack +/// - flag_m: boolean flag indicating the memory chiplet is active (i.e. range checks are required). +/// This flag is degree 3. 
+/// - mv0-mv1: memory value 0-1, the 2 values range-checked from the memory chiplet +/// +/// The constraint expression looks as follows: +/// b' = b + rc_multiplicity / (alpha - rc_value) +/// - flag_s / (alpha - sv0) - flag_s / (alpha - sv1) +/// - flag_s / (alpha - sv2) - flag_s / (alpha - sv3) +/// - flag_m / (alpha - mv0) - flag_m / (alpha - mv1) +/// +/// However, to enforce the constraint, all denominators are multiplied so that no divisions are +/// included in the actual constraint expression. +/// +/// Constraint degree: 9 +fn enforce_b_range( main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, alpha: E, @@ -163,53 +160,43 @@ fn enforce_running_product_b_range( F: FieldElement, E: FieldElement + ExtensionOf, { - // The running product column must enforce that the next step has the values from the range - // checker multiplied in (z) and the values from the stack (q) and the memory divided out. This - // is enforced by ensuring that b_range_next multiplied by the stack and memory lookups at this step - // is equal to the combination of b_range and the range checker's values for this step. - let lookups = aux_frame.q() * get_memory_lookups(main_frame, alpha); - let range_checks = get_z(main_frame, alpha); - - result[0] = are_equal(aux_frame.b_range_next() * lookups, aux_frame.b_range() * range_checks); -} - -/// The value to be included in the running product column for memory lookups at this row. These are -/// only included for steps in the memory section of the trace (when the memory_flag is one). -fn get_memory_lookups(main_frame: &EvaluationFrame, alpha: E) -> E -where - F: FieldElement, - E: FieldElement + ExtensionOf, -{ - let memory_flag: E = main_frame.chiplets_memory_flag().into(); - let d0: E = main_frame.memory_d0().into(); - let d1: E = main_frame.memory_d1().into(); - - E::ONE + memory_flag * ((d0 + alpha) * (d1 + alpha) - E::ONE) -} - -/// Returns the value `z` which is included in the running product columns at each step. 
`z` causes -/// the row's value to be included 0, 1, 2, or 4 times, according to the row's selector flags row. -fn get_z(main_frame: &EvaluationFrame, alpha: E) -> E -where - F: FieldElement, - E: FieldElement + ExtensionOf, -{ - // Get the selectors and the value from the main frame. - let s0: E = main_frame.s0().into(); - let s1: E = main_frame.s1().into(); - let v: E = main_frame.v().into(); - - // Define the flags. - let f0: E = binary_not(s0) * binary_not(s1); - let f1: E = s0 * binary_not(s1); - let f2: E = binary_not(s0) * s1; - let f3: E = s0 * s1; - - // Compute z. - let v_alpha = v + alpha; - let v_alpha2 = v_alpha.square(); - let v_alpha4 = v_alpha2.square(); - f3 * v_alpha4 + f2 * v_alpha2 + f1 * v_alpha + f0 + // The denominator values for the LogUp lookup. + let mv0: E = main_frame.lookup_mv0(alpha); + let mv1: E = main_frame.lookup_mv1(alpha); + let sv0: E = main_frame.lookup_sv0(alpha); + let sv1: E = main_frame.lookup_sv1(alpha); + let sv2: E = main_frame.lookup_sv2(alpha); + let sv3: E = main_frame.lookup_sv3(alpha); + let range_check: E = alpha - main_frame.v().into(); + let memory_lookups: E = mv0.mul(mv1); // degree 2 + let stack_lookups: E = sv0.mul(sv1).mul(sv2).mul(sv3); // degree 4 + let lookups = range_check.mul(stack_lookups).mul(memory_lookups); // degree 7 + + // An intermediate value required by all stack terms that includes the flag indicating a stack + // operation with range checks. This value has degree 6. + let sflag_rc_mem: E = range_check + .mul(memory_lookups) + .mul_base( as MainFrameExt>::u32_rc_op(main_frame)); + // An intermediate value required by all memory terms that includes the flag indicating the + // memory portion of the chiplets trace. This value has degree 8. + let mflag_rc_stack: E = + range_check.mul(stack_lookups).mul_base(main_frame.chiplets_memory_flag()); + + // The terms for the LogUp check after all denominators have been multiplied in. 
+ let b_next_term = aux_frame.b_range_next().mul(lookups); // degree 8 + let b_term = aux_frame.b_range().mul(lookups); // degree 8 + let rc_term = stack_lookups.mul(memory_lookups).mul_base(main_frame.multiplicity()); // degree 7 + let s0_term = sflag_rc_mem.mul(sv1).mul(sv2).mul(sv3); // degree 9 + let s1_term = sflag_rc_mem.mul(sv0).mul(sv2).mul(sv3); // degree 9 + let s2_term = sflag_rc_mem.mul(sv0).mul(sv1).mul(sv3); // degree 9 + let s3_term = sflag_rc_mem.mul(sv0).mul(sv1).mul(sv2); // degree 9 + let m0_term = mflag_rc_stack.mul(mv1); // degree 9 + let m1_term = mflag_rc_stack.mul(mv0); // degree 9 + + result[0] = are_equal( + b_next_term, + b_term + rc_term - s0_term - s1_term - s2_term - s3_term - m0_term - m1_term, + ); } // RANGE CHECKER FRAME EXTENSION TRAIT @@ -220,22 +207,15 @@ where trait EvaluationFrameExt { // --- Column accessors ----------------------------------------------------------------------- - fn s0(&self) -> E; - /// The current value in column s1. - fn s1(&self) -> E; + /// The current value in the lookup multiplicity column. + fn multiplicity(&self) -> E; /// The current value in column V. fn v(&self) -> E; - /// The next value in column V. - fn v_next(&self) -> E; /// The current value in auxiliary column b_range. fn b_range(&self) -> E; - /// The next value in auxiliary column b_range. fn b_range_next(&self) -> E; - /// The current value in auxiliary column q. 
- fn q(&self) -> E; - // --- Intermediate variables & helpers ------------------------------------------------------- /// The change between the current value in the specified column and the next value, calculated @@ -247,13 +227,8 @@ impl EvaluationFrameExt for &EvaluationFrame { // --- Column accessors ----------------------------------------------------------------------- #[inline(always)] - fn s0(&self) -> E { - self.current()[S0_COL_IDX] - } - - #[inline(always)] - fn s1(&self) -> E { - self.current()[S1_COL_IDX] + fn multiplicity(&self) -> E { + self.current()[M_COL_IDX] } #[inline(always)] @@ -261,11 +236,6 @@ impl EvaluationFrameExt for &EvaluationFrame { self.current()[V_COL_IDX] } - #[inline(always)] - fn v_next(&self) -> E { - self.next()[V_COL_IDX] - } - #[inline(always)] fn b_range(&self) -> E { self.current()[B_RANGE_COL_IDX] @@ -276,11 +246,6 @@ impl EvaluationFrameExt for &EvaluationFrame { self.next()[B_RANGE_COL_IDX] } - #[inline(always)] - fn q(&self) -> E { - self.current()[Q_COL_IDX] - } - // --- Intermediate variables & helpers ------------------------------------------------------- #[inline(always)] diff --git a/air/src/constraints/stack/op_flags/mod.rs b/air/src/constraints/stack/op_flags/mod.rs index 1dc0c0958f..9a684326a7 100644 --- a/air/src/constraints/stack/op_flags/mod.rs +++ b/air/src/constraints/stack/op_flags/mod.rs @@ -989,7 +989,7 @@ impl OpFlags { self.control_flow } - /// Returns the flag when the stack operation is a u32 operation. + /// Returns true when the stack operation is a u32 operation that requires range checks. #[inline(always)] pub fn u32_rc_op(&self) -> E { self.u32_rc_op diff --git a/air/src/trace/decoder/mod.rs b/air/src/trace/decoder/mod.rs index d0e8f54f79..2795cc9646 100644 --- a/air/src/trace/decoder/mod.rs +++ b/air/src/trace/decoder/mod.rs @@ -97,3 +97,8 @@ pub const P2_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 1; /// Running product column representing op group table. 
pub const P3_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 2; + +// --- GLOBALLY-INDEXED DECODER COLUMN ACCESSORS -------------------------------------------------- +pub const DECODER_OP_BITS_OFFSET: usize = super::DECODER_TRACE_OFFSET + OP_BITS_OFFSET; +pub const DECODER_USER_OP_HELPERS_OFFSET: usize = + super::DECODER_TRACE_OFFSET + USER_OP_HELPERS_OFFSET; diff --git a/air/src/trace/mod.rs b/air/src/trace/mod.rs index 81c45b2fc5..fa908671f4 100644 --- a/air/src/trace/mod.rs +++ b/air/src/trace/mod.rs @@ -42,7 +42,7 @@ pub const STACK_TRACE_RANGE: Range = range(STACK_TRACE_OFFSET, STACK_TRAC // Range check trace pub const RANGE_CHECK_TRACE_OFFSET: usize = STACK_TRACE_RANGE.end; -pub const RANGE_CHECK_TRACE_WIDTH: usize = 3; +pub const RANGE_CHECK_TRACE_WIDTH: usize = 2; pub const RANGE_CHECK_TRACE_RANGE: Range = range(RANGE_CHECK_TRACE_OFFSET, RANGE_CHECK_TRACE_WIDTH); @@ -57,7 +57,7 @@ pub const TRACE_WIDTH: usize = CHIPLETS_OFFSET + CHIPLETS_WIDTH; // ------------------------------------------------------------------------------------------------ // decoder stack range checks hasher chiplets -// (3 columns) (1 column) (2 columns) (1 column) (1 column) +// (3 columns) (1 column) (1 column) (1 column) (1 column) // ├───────────────┴──────────────┴──────────────┴───────────────┴───────────────┤ // Decoder auxiliary columns @@ -74,7 +74,7 @@ pub const STACK_AUX_TRACE_RANGE: Range = // Range check auxiliary columns pub const RANGE_CHECK_AUX_TRACE_OFFSET: usize = STACK_AUX_TRACE_RANGE.end; -pub const RANGE_CHECK_AUX_TRACE_WIDTH: usize = 2; +pub const RANGE_CHECK_AUX_TRACE_WIDTH: usize = 1; pub const RANGE_CHECK_AUX_TRACE_RANGE: Range = range(RANGE_CHECK_AUX_TRACE_OFFSET, RANGE_CHECK_AUX_TRACE_WIDTH); diff --git a/air/src/trace/range.rs b/air/src/trace/range.rs index c86fffde31..3081bbe3a8 100644 --- a/air/src/trace/range.rs +++ b/air/src/trace/range.rs @@ -5,21 +5,13 @@ use super::{RANGE_CHECK_AUX_TRACE_OFFSET, RANGE_CHECK_TRACE_OFFSET}; // --- Column accessors in the 
main trace --------------------------------------------------------- -/// A binary selector column to help specify whether or not the value should be included in the -/// running product. -pub const S0_COL_IDX: usize = RANGE_CHECK_TRACE_OFFSET; -/// A binary selector column to help specify whether or not the value should be included in the -/// running product. -pub const S1_COL_IDX: usize = RANGE_CHECK_TRACE_OFFSET + 1; +/// A column to hold the multiplicity of how many times the value is being range-checked. +pub const M_COL_IDX: usize = RANGE_CHECK_TRACE_OFFSET; /// A column to hold the values being range-checked. -pub const V_COL_IDX: usize = RANGE_CHECK_TRACE_OFFSET + 2; +pub const V_COL_IDX: usize = RANGE_CHECK_TRACE_OFFSET + 1; // --- Column accessors in the auxiliary columns -------------------------------------------------- /// The running product column used for verifying that the range check lookups performed in the /// Stack and the Memory chiplet match the values checked in the Range Checker. pub const B_RANGE_COL_IDX: usize = RANGE_CHECK_AUX_TRACE_OFFSET; - -/// An auxiliary trace column of intermediate values used to enforce AIR constraints on `b_range`. -/// It contains the product of the lookups performed by the Stack processor at each cycle. 
-pub const Q_COL_IDX: usize = B_RANGE_COL_IDX + 1; From 3c64447e2b9cd7de34aa9374a0c2410f6402b7c8 Mon Sep 17 00:00:00 2001 From: grjte Date: Fri, 28 Jul 2023 09:55:15 -0400 Subject: [PATCH 06/18] feat(proc): update range checker trace gen main/aux --- processor/src/chiplets/memory/mod.rs | 2 +- processor/src/operations/u32_ops.rs | 2 +- processor/src/range/aux_trace.rs | 116 ++++++--------- processor/src/range/mod.rs | 205 ++++++--------------------- processor/src/range/tests.rs | 6 +- processor/src/trace/tests/range.rs | 45 +----- 6 files changed, 93 insertions(+), 283 deletions(-) diff --git a/processor/src/chiplets/memory/mod.rs b/processor/src/chiplets/memory/mod.rs index 41c254dfe1..2b87904fda 100644 --- a/processor/src/chiplets/memory/mod.rs +++ b/processor/src/chiplets/memory/mod.rs @@ -176,7 +176,7 @@ impl Memory { }; let (delta_hi, delta_lo) = split_u32_into_u16(delta); - range.add_mem_checks(row, &[delta_lo, delta_hi]); + range.add_range_checks(row, &[delta_lo, delta_hi]); // update values for the next iteration of the loop prev_ctx = ctx; diff --git a/processor/src/operations/u32_ops.rs b/processor/src/operations/u32_ops.rs index 28d2077c50..7c14d768ec 100644 --- a/processor/src/operations/u32_ops.rs +++ b/processor/src/operations/u32_ops.rs @@ -210,7 +210,7 @@ where let (t3, t2) = split_u32_into_u16(hi.as_int()); // add lookup values to the range checker. - self.range.add_stack_checks(self.system.clk(), &[t0, t1, t2, t3]); + self.range.add_range_checks(self.system.clk(), &[t0, t1, t2, t3]); // save the range check lookups to the decoder's user operation helper columns. 
let mut helper_values = diff --git a/processor/src/range/aux_trace.rs b/processor/src/range/aux_trace.rs index 548b112ef4..415689e7e0 100644 --- a/processor/src/range/aux_trace.rs +++ b/processor/src/range/aux_trace.rs @@ -1,8 +1,5 @@ -use super::{ - build_lookup_table_row_values, uninit_vector, BTreeMap, ColMatrix, CycleRangeChecks, Felt, - FieldElement, RangeCheckFlag, Vec, NUM_RAND_ROWS, -}; -use miden_air::trace::range::V_COL_IDX; +use super::{uninit_vector, BTreeMap, ColMatrix, Felt, FieldElement, Vec, NUM_RAND_ROWS}; +use miden_air::trace::range::{M_COL_IDX, V_COL_IDX}; // AUXILIARY TRACE BUILDER // ================================================================================================ @@ -10,16 +7,9 @@ use miden_air::trace::range::V_COL_IDX; /// Describes how to construct the execution trace of columns related to the range checker in the /// auxiliary segment of the trace. These are used in multiset checks. pub struct AuxTraceBuilder { - // Range check lookups performed by all user operations, grouped and sorted by clock cycle. Each - // cycle is mapped to a single CycleRangeChecks instance which includes lookups from the stack, - // memory, or both. - // TODO: once we switch to backfilling memory range checks this approach can change to tracking - // vectors of hints and rows like in the Stack and Hasher AuxTraceBuilders, and the - // CycleRangeChecks struct can be removed. - cycle_range_checks: BTreeMap, - // A trace-length vector of RangeCheckFlags which indicate how many times the range check value - // at that row should be included in the trace. - row_flags: Vec, + /// Range check lookups performed by all user operations, grouped and sorted by the clock cycle + /// at which they are requested. + cycle_range_checks: BTreeMap>, // The index of the first row of Range Checker's trace when the padded rows end and values to // be range checked start. 
values_start: usize, @@ -28,41 +18,27 @@ pub struct AuxTraceBuilder { impl AuxTraceBuilder { // CONSTRUCTOR // -------------------------------------------------------------------------------------------- - pub fn new( - cycle_range_checks: BTreeMap, - row_flags: Vec, - values_start: usize, - ) -> Self { + pub fn new(cycle_range_checks: BTreeMap>, values_start: usize) -> Self { Self { cycle_range_checks, - row_flags, values_start, } } - // ACCESSORS - // -------------------------------------------------------------------------------------------- - pub fn cycle_range_check_values(&self) -> Vec { - self.cycle_range_checks.values().cloned().collect() - } - // AUX COLUMN BUILDERS // -------------------------------------------------------------------------------------------- /// Builds and returns range checker auxiliary trace columns. Currently this consists of two /// columns: /// - `b_range`: ensures that the range checks performed by the Range Checker match those - /// requested - /// by the Stack and Memory processors. - /// - `q`: a helper column of intermediate values to reduce the degree of the constraints for - /// `b_range`. It contains the product of the lookups performed by the Stack at each row. + /// requested by the Stack and Memory processors. pub fn build_aux_columns>( &self, main_trace: &ColMatrix, rand_elements: &[E], ) -> Vec> { - let (b_range, q) = self.build_aux_col_b_range(main_trace, rand_elements); - vec![b_range, q] + let b_range = self.build_aux_col_b_range(main_trace, rand_elements); + vec![b_range] } /// Builds the execution trace of the range check `b_range` and `q` columns which ensure that the @@ -71,46 +47,40 @@ impl AuxTraceBuilder { &self, main_trace: &ColMatrix, alphas: &[E], - ) -> (Vec, Vec) { - // compute the inverses for range checks performed by operations. 
- let (_, inv_row_values) = - build_lookup_table_row_values(&self.cycle_range_check_values(), main_trace, alphas); + ) -> Vec { + // TODO: replace this with an efficient solution + // // compute the inverses for range checks performed by operations. + // let (_, inv_row_values) = + // build_lookup_table_row_values(&self.cycle_range_check_values(), main_trace, alphas); - // allocate memory for the running product column and set the initial value to ONE - let mut q = unsafe { uninit_vector(main_trace.num_rows()) }; + // allocate memory for the running sum column and set the initial value to ONE let mut b_range = unsafe { uninit_vector(main_trace.num_rows()) }; - q[0] = E::ONE; b_range[0] = E::ONE; - // keep track of the last updated row in the `b_range` running product column. the `q` - // column index is always one row behind, since `q` is filled with intermediate values in - // the same row as the operation is executed, whereas `b_range` is filled with result - // values that are added to the next row after the operation's execution. + // keep track of the last updated row in the `b_range` running sum column. `b_range` is + // filled with result values that are added to the next row after the operation's execution. let mut b_range_idx = 0_usize; - // keep track of the next row to be included from the user op range check values. - let mut rc_user_op_idx = 0; // the first half of the trace only includes values from the operations. for (clk, range_checks) in self.cycle_range_checks.range(0..self.values_start as u32) { let clk = *clk as usize; // if we skipped some cycles since the last update was processed, values in the last - // updated row should by copied over until the current cycle. + // updated row should be copied over until the current cycle. if b_range_idx < clk { let last_value = b_range[b_range_idx]; b_range[(b_range_idx + 1)..=clk].fill(last_value); - q[b_range_idx..clk].fill(E::ONE); } - // move the column pointers to the next row. 
+ // move the column pointer to the next row. b_range_idx = clk + 1; - // update the intermediate values in the q column. - q[clk] = range_checks.to_stack_value(main_trace, alphas); - - // include the operation lookups in the running product. - b_range[b_range_idx] = b_range[clk] * inv_row_values[rc_user_op_idx]; - rc_user_op_idx += 1; + b_range[b_range_idx] = b_range[clk]; + // include the operation lookups + for lookup in range_checks.iter() { + let value = (alphas[0] - (*lookup).into()).inv(); + b_range[b_range_idx] -= value; + } } // if we skipped some cycles since the last update was processed, values in the last @@ -118,13 +88,13 @@ impl AuxTraceBuilder { if b_range_idx < self.values_start { let last_value = b_range[b_range_idx]; b_range[(b_range_idx + 1)..=self.values_start].fill(last_value); - q[b_range_idx..self.values_start].fill(E::ONE); } - // after the padded section of the range checker table, include `z` in the running product - // at each step and remove lookups from user ops at any step where user ops were executed. - for (row_idx, (hint, lookup)) in self - .row_flags + // after the padded section of the range checker table, include the lookup value specified + // by the range checker into the running sum at each step, and remove lookups from user ops + // at any step where user ops were executed. + for (row_idx, (multiplicity, lookup)) in main_trace + .get_column(M_COL_IDX) .iter() .zip(main_trace.get_column(V_COL_IDX).iter()) .enumerate() @@ -133,32 +103,26 @@ impl AuxTraceBuilder { { b_range_idx = row_idx + 1; - b_range[b_range_idx] = b_range[row_idx] * hint.to_value(*lookup, alphas); - - if let Some(range_check) = self.cycle_range_checks.get(&(row_idx as u32)) { - // update the intermediate values in the q column. - q[row_idx] = range_check.to_stack_value(main_trace, alphas); - - // include the operation lookups in the running product. 
- b_range[b_range_idx] *= inv_row_values[rc_user_op_idx]; - rc_user_op_idx += 1; - } else { - q[row_idx] = E::ONE; + // add the value in the range checker: multiplicity / (alpha - lookup) + let lookup_val = (alphas[0] - (*lookup).into()).inv().mul_base(*multiplicity); + b_range[b_range_idx] = b_range[row_idx] + lookup_val; + // subtract the range checks requested by operations + if let Some(range_checks) = self.cycle_range_checks.get(&(row_idx as u32)) { + for lookup in range_checks.iter() { + let value = (alphas[0] - (*lookup).into()).inv(); + b_range[b_range_idx] -= value; + } } } // at this point, all range checks from user operations and the range checker should be // matched - so, the last value must be ONE; - assert_eq!(q[b_range_idx - 1], E::ONE); assert_eq!(b_range[b_range_idx], E::ONE); - if (b_range_idx - 1) < b_range.len() - 1 { - q[b_range_idx..].fill(E::ONE); - } if b_range_idx < b_range.len() - 1 { b_range[(b_range_idx + 1)..].fill(E::ONE); } - (b_range, q) + b_range } } diff --git a/processor/src/range/mod.rs b/processor/src/range/mod.rs index 14c5e58ba5..d0794180bc 100644 --- a/processor/src/range/mod.rs +++ b/processor/src/range/mod.rs @@ -1,15 +1,11 @@ use super::{ - trace::{build_lookup_table_row_values, LookupTableRow, NUM_RAND_ROWS}, - utils::uninit_vector, - BTreeMap, ColMatrix, Felt, FieldElement, RangeCheckTrace, Vec, ONE, ZERO, + trace::NUM_RAND_ROWS, utils::uninit_vector, BTreeMap, ColMatrix, Felt, FieldElement, + RangeCheckTrace, Vec, ZERO, }; mod aux_trace; pub use aux_trace::AuxTraceBuilder; -mod request; -use request::CycleRangeChecks; - #[cfg(test)] mod tests; @@ -23,35 +19,31 @@ mod tests; /// into 16-bits, but rather keeps track of all 16-bit range checks performed by the VM. /// /// ## Execution trace -/// Execution trace generated by the range checker consists of 3 columns. Conceptually, the table -/// starts with value 0 and end with value 65535. +/// The execution trace generated by the range checker consists of 2 columns. 
Conceptually, the +/// table starts with value 0 and ends with value 65535. /// -/// The layout illustrated below. +/// The layout is illustrated below. /// -/// s0 s1 v -/// ├─────┴──────┴─────┤ +/// m v +/// ├─────┴─────┤ /// /// In the above, the meaning of the columns is as follows: /// - Column `v` contains the value being range-checked where `v` must be a 16-bit value. The /// values must be in increasing order and the jump allowed between two values should be a power /// of 3 less than or equal to 3^7, and duplicates are allowed. -/// - Column `s0` and `s1` specify how many lookups are to be included for a given value. -/// Specifically: (0, 0) means no lookups, (1, 0) means one lookup, (0, 1), means two lookups, -/// and (1, 1) means four lookups. +/// - Column `m` specifies the lookup multiplicity, which is how many lookups are to be included for +/// a given value. /// /// Thus, for example, if a value was range-checked just once, we'll need to add a single row to -/// the table with (s0, s1, v) set to (1, 0, v), where v is the value. -/// -/// If, on the other hand, the value was range-checked 5 times, we'll need two rows in the table: -/// (1, 1, v) and (1, 0, v). The first row specifies that there were 4 lookups and the second -/// row add the fifth lookup. +/// the table with (m, v) set to (1, v), where v is the value. If the value was range-checked 5 +/// times, we'll need to specify the row (5, v). pub struct RangeChecker { /// Tracks lookup count for each checked value. lookups: BTreeMap, - // Range check lookups performed by all user operations, grouped and sorted by clock cycle. Each - // cycle is mapped to a single CycleRangeChecks instance which includes lookups from the stack, - // memory, or both. - cycle_range_checks: BTreeMap, + /// Range check lookups performed by all user operations, grouped and sorted by clock cycle. 
+ /// Each cycle is mapped to a vector of the range checks requested at that cycle, which can come + /// from the stack, memory, or both. + cycle_range_checks: BTreeMap>, } impl RangeChecker { @@ -72,41 +64,36 @@ impl RangeChecker { // TRACE MUTATORS // -------------------------------------------------------------------------------------------- + /// Adds the specified value to the trace of this range checker's lookups. pub fn add_value(&mut self, value: u16) { self.lookups.entry(value).and_modify(|v| *v += 1).or_insert(1); } - /// Adds range check lookups from the [Stack] to this [RangeChecker] instance. Stack lookups are - /// guaranteed to be added at unique clock cycles, since operations are sequential and no range - /// check lookups are added before or during the stack operation processing. - pub fn add_stack_checks(&mut self, clk: u32, values: &[u16; 4]) { - self.add_value(values[0]); - self.add_value(values[1]); - self.add_value(values[2]); - self.add_value(values[3]); - - // Stack operations are added before memory operations at unique clock cycles. - self.cycle_range_checks.insert(clk, CycleRangeChecks::new_from_stack(values)); - } + /// Adds range check lookups from the stack or memory to this [RangeChecker] instance. + pub fn add_range_checks(&mut self, clk: u32, values: &[u16]) { + // range checks requests only come from memory or from the stack, which always request 2 or + // 4 lookups respectively. + debug_assert!(values.len() == 2 || values.len() == 4); - /// Adds range check lookups from [Memory] to this [RangeChecker] instance. Memory lookups are - /// always added after all stack lookups have completed, since they are processed during trace - /// finalization. - pub fn add_mem_checks(&mut self, clk: u32, values: &[u16; 2]) { - self.add_value(values[0]); - self.add_value(values[1]); + let mut requests = Vec::new(); + for value in values.iter() { + // add the specified value to the trace of this range checker's lookups. 
+ self.add_value(*value); + requests.push(Felt::from(*value)); + } + // track the range check requests at each cycle self.cycle_range_checks .entry(clk) - .and_modify(|entry| entry.add_memory_checks(values)) - .or_insert_with(|| CycleRangeChecks::new_from_memory(values)); + .and_modify(|entry| entry.append(&mut requests)) + .or_insert_with(|| requests); } // EXECUTION TRACE GENERATION (INTERNAL) // -------------------------------------------------------------------------------------------- - /// Converts this [RangeChecker] into an execution trace with 3 columns and the number of rows + /// Converts this [RangeChecker] into an execution trace with 2 columns and the number of rows /// specified by the `target_len` parameter. /// /// If the number of rows need to represent execution trace of this range checker is smaller @@ -133,28 +120,19 @@ impl RangeChecker { // allocated memory for the trace; this memory is un-initialized but this is not a problem // because we'll overwrite all values in it anyway. - let mut trace = unsafe { - [uninit_vector(target_len), uninit_vector(target_len), uninit_vector(target_len)] - }; - // Allocate uninitialized memory for accumulating the precomputed auxiliary column hints. - let mut row_flags = unsafe { uninit_vector(target_len) }; + let mut trace = unsafe { [uninit_vector(target_len), uninit_vector(target_len)] }; // determine the number of padding rows needed to get to target trace length and pad the // table with the required number of rows. let num_padding_rows = target_len - trace_len - num_rand_rows; trace[0][..num_padding_rows].fill(ZERO); trace[1][..num_padding_rows].fill(ZERO); - trace[2][..num_padding_rows].fill(ZERO); - - // Initialize the padded rows of the auxiliary column hints with the default flag, F0, - // indicating s0 = s1 = ZERO. 
- row_flags[..num_padding_rows].fill(RangeCheckFlag::F0); // build the trace table let mut i = num_padding_rows; let mut prev_value = 0u16; for (&value, &num_lookups) in self.lookups.iter() { - write_rows(&mut trace, &mut i, num_lookups, value, prev_value, &mut row_flags); + write_rows(&mut trace, &mut i, num_lookups, value, prev_value); prev_value = value; } @@ -163,11 +141,11 @@ impl RangeChecker { // (When there is data at the end of the main trace, auxiliary bus columns always need to be // one row longer than the main trace, since values in the bus column are based on data from // the "current" row of the main trace but placed into the "next" row of the bus column.) - write_value(&mut trace, &mut i, 0, (u16::MAX).into(), &mut row_flags); + write_trace_row(&mut trace, &mut i, 0, (u16::MAX).into()); RangeCheckTrace { trace, - aux_builder: AuxTraceBuilder::new(self.cycle_range_checks, row_flags, num_padding_rows), + aux_builder: AuxTraceBuilder::new(self.cycle_range_checks, num_padding_rows), } } @@ -181,16 +159,16 @@ impl RangeChecker { let mut num_rows = 1; let mut prev_value = 0u16; - for (&value, &num_lookups) in self.lookups.iter() { - // determine how many lookup rows we need for this value - num_rows += lookups_to_rows(num_lookups); + for value in self.lookups.keys() { + // add one row for each value in the range checker table + num_rows += 1; // determine the delta between this and the previous value. we need to know this delta // to determine if we need to insert any "bridge" rows to the table, this is needed // since the gap between two values in the range checker can only be a power of 3 less // than or equal to 3^7. 
let delta = value - prev_value; num_rows += get_num_bridge_rows(delta); - prev_value = value; + prev_value = *value; } num_rows } @@ -222,58 +200,9 @@ impl Default for RangeChecker { } } -// RANGE CHECKER ROWS -// ================================================================================================ - -/// A precomputed hint value that can be used to help construct the execution trace for the -/// auxiliary column b_range used for multiset checks. The hint is a precomputed flag value based -/// on the selectors s0 and s1 in the trace. -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum RangeCheckFlag { - F0, - F1, - F2, - F3, -} - -impl RangeCheckFlag { - /// Reduces this row to a single field element in the field specified by E. This requires - /// at least 1 alpha value. - pub fn to_value>(&self, value: Felt, alphas: &[E]) -> E { - let alpha: E = alphas[0]; - - match self { - RangeCheckFlag::F0 => E::ONE, - RangeCheckFlag::F1 => alpha + value.into(), - RangeCheckFlag::F2 => (alpha + value.into()).square(), - RangeCheckFlag::F3 => ((alpha + value.into()).square()).square(), - } - } -} - // HELPER FUNCTIONS // ================================================================================================ -/// Returns the number of rows needed to perform the specified number of lookups for an 8-bit -/// value. Note that even if the number of lookups is 0, at least one row is required. This is -/// because for an 8-bit table, rows must contain contiguous values. -/// -/// The number of rows is determined as follows: -/// - First we compute the number of rows for 4 lookups per row. -/// - Then we compute the number of rows for 2 lookups per row. -/// - Then, we compute the number of rows for a single lookup per row. -/// -/// The return value is the sum of these three values. 
-fn lookups_to_rows(num_lookups: usize) -> usize { - if num_lookups == 0 { - 1 - } else { - let (num_rows4, num_lookups) = div_rem(num_lookups, 4); - let (num_rows2, num_rows1) = div_rem(num_lookups, 2); - num_rows4 + num_rows2 + num_rows1 - } -} - /// Calculates the number of bridge rows that are need to be added to the trace between two values /// to be range checked. pub fn get_num_bridge_rows(delta: u16) -> usize { @@ -299,7 +228,6 @@ fn write_rows( num_lookups: usize, value: u16, prev_value: u16, - row_flags: &mut [RangeCheckFlag], ) { let mut gap = value - prev_value; let mut prev_val = prev_value; @@ -308,62 +236,17 @@ fn write_rows( if gap > stride { gap -= stride; prev_val += stride; - write_value(trace, step, 0, prev_val as u64, row_flags); + write_trace_row(trace, step, 0, prev_val as u64); } else { stride /= 3; } } - write_value(trace, step, num_lookups, value as u64, row_flags); -} - -/// Populates the trace with the rows needed to support the specified number of lookups against -/// the specified value. 
-fn write_value( - trace: &mut [Vec], - step: &mut usize, - num_lookups: usize, - value: u64, - row_flags: &mut [RangeCheckFlag], -) { - // if the number of lookups is 0, only one trace row is required - if num_lookups == 0 { - row_flags[*step] = RangeCheckFlag::F0; - write_trace_row(trace, step, ZERO, ZERO, value); - return; - } - - // write rows which can support 4 lookups per row - let (num_rows, num_lookups) = div_rem(num_lookups, 4); - for _ in 0..num_rows { - row_flags[*step] = RangeCheckFlag::F3; - write_trace_row(trace, step, ONE, ONE, value); - } - - // write rows which can support 2 lookups per row - let (num_rows, num_lookups) = div_rem(num_lookups, 2); - for _ in 0..num_rows { - row_flags[*step] = RangeCheckFlag::F2; - write_trace_row(trace, step, ZERO, ONE, value); - } - - // write rows which can support only one lookup per row - for _ in 0..num_lookups { - row_flags[*step] = RangeCheckFlag::F1; - write_trace_row(trace, step, ONE, ZERO, value); - } + write_trace_row(trace, step, num_lookups, value as u64); } /// Populates a single row at the specified step in the trace table. -fn write_trace_row(trace: &mut [Vec], step: &mut usize, s0: Felt, s1: Felt, value: u64) { - trace[0][*step] = s0; - trace[1][*step] = s1; - trace[2][*step] = Felt::new(value); +fn write_trace_row(trace: &mut [Vec], step: &mut usize, num_lookups: usize, value: u64) { + trace[0][*step] = Felt::new(num_lookups as u64); + trace[1][*step] = Felt::new(value); *step += 1; } - -/// Returns quotient and remainder of dividing the provided value by the divisor. 
-fn div_rem(value: usize, divisor: usize) -> (usize, usize) { - let q = value / divisor; - let r = value % divisor; - (q, r) -} diff --git a/processor/src/range/tests.rs b/processor/src/range/tests.rs index 7ec9b3d3ff..0ca4db8fd1 100644 --- a/processor/src/range/tests.rs +++ b/processor/src/range/tests.rs @@ -1,9 +1,10 @@ -use super::{BTreeMap, Felt, RangeChecker, Vec, ONE, ZERO}; +use super::{super::ONE, BTreeMap, Felt, RangeChecker, Vec, ZERO}; use crate::{utils::get_trace_len, RangeCheckTrace}; use rand_utils::rand_array; use vm_core::{utils::ToElements, StarkField}; #[test] +#[ignore = "update required"] fn range_checks() { let mut checker = RangeChecker::new(); @@ -21,7 +22,7 @@ fn range_checks() { // skip the padded rows let mut i = 0; - while trace[0][i] == ZERO && trace[1][i] == ZERO && trace[2][i] == ZERO { + while trace[0][i] == ZERO && trace[1][i] == ZERO { i += 1; } @@ -47,6 +48,7 @@ fn range_checks() { } #[test] +#[ignore = "update required"] fn range_checks_rand() { let mut checker = RangeChecker::new(); let values = rand_array::(); diff --git a/processor/src/trace/tests/range.rs b/processor/src/trace/tests/range.rs index 54035cbb66..96e1c44e2a 100644 --- a/processor/src/trace/tests/range.rs +++ b/processor/src/trace/tests/range.rs @@ -1,56 +1,16 @@ use super::{build_trace_from_ops, Felt, FieldElement, Trace, NUM_RAND_ROWS, ONE, ZERO}; use miden_air::trace::{ - chiplets::hasher::HASH_CYCLE_LEN, - range::{B_RANGE_COL_IDX, Q_COL_IDX}, - AUX_TRACE_RAND_ELEMENTS, + chiplets::hasher::HASH_CYCLE_LEN, range::B_RANGE_COL_IDX, AUX_TRACE_RAND_ELEMENTS, }; use rand_utils::rand_array; use vm_core::Operation; -#[test] -#[allow(clippy::needless_range_loop)] -fn q_trace() { - let stack = [1, 255]; - let operations = vec![ - Operation::U32add, - Operation::MStoreW, - Operation::Drop, - Operation::Drop, - Operation::Drop, - Operation::Drop, - ]; - let mut trace = build_trace_from_ops(operations, &stack); - - let rand_elements = rand_array::(); - let alpha = 
rand_elements[0]; - let aux_columns = trace.build_aux_segment(&[], &rand_elements).unwrap(); - let q = aux_columns.get_column(Q_COL_IDX); - - assert_eq!(trace.length(), q.len()); - - // --- Check the stack processor's range check lookups. --------------------------------------- - - // Before any range checks are executed, the value in b_range should be one. - assert_eq!(Felt::ONE, q[0]); - - // The first range check lookup from the stack will happen when the add operation is executed, - // at cycle 1. (The trace begins by executing `span`). It must be divided out of `b_range`. - // The range-checked values are 0, 256, 0, 0. - let expected = (alpha) * (Felt::new(256) + alpha) * alpha.square(); - assert_eq!(expected, q[1]); - - // --- Check the last value of the q column is one. ------------------------------------------ - - for row in 2..(q.len() - NUM_RAND_ROWS) { - assert_eq!(Felt::ONE, q[row]); - } -} - /// This test checks that range check lookups from stack operations are balanced by the range checks /// processed in the Range Checker. /// /// The `U32add` operation results in 4 16-bit range checks of 256, 0, 0, 0. #[test] +#[ignore = "update required"] fn b_range_trace_stack() { let stack = [1, 255]; let operations = vec![Operation::U32add]; @@ -115,6 +75,7 @@ fn b_range_trace_stack() { /// The `LoadW` memory operation results in 2 16-bit range checks of 0, 0. 
#[test] #[allow(clippy::needless_range_loop)] +#[ignore = "update required"] fn b_range_trace_mem() { let stack = [0, 1, 2, 3, 4, 0]; let operations = vec![ From 5a78583e21a411c3e98d42c64848a7273dd1e786 Mon Sep 17 00:00:00 2001 From: grjte Date: Fri, 28 Jul 2023 09:59:44 -0400 Subject: [PATCH 07/18] test(proc): update range checker tests --- processor/src/range/tests.rs | 54 +++++++----------------- processor/src/trace/tests/range.rs | 68 ++++++++++++++---------------- 2 files changed, 48 insertions(+), 74 deletions(-) diff --git a/processor/src/range/tests.rs b/processor/src/range/tests.rs index 0ca4db8fd1..1cdce79756 100644 --- a/processor/src/range/tests.rs +++ b/processor/src/range/tests.rs @@ -1,17 +1,17 @@ -use super::{super::ONE, BTreeMap, Felt, RangeChecker, Vec, ZERO}; +use super::{BTreeMap, Felt, RangeChecker, Vec, ZERO}; use crate::{utils::get_trace_len, RangeCheckTrace}; use rand_utils::rand_array; use vm_core::{utils::ToElements, StarkField}; #[test] -#[ignore = "update required"] fn range_checks() { let mut checker = RangeChecker::new(); let values = [0, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 100, 355, 620].to_elements(); for &value in values.iter() { - checker.add_value(value.as_int() as u16) + // add the value to the range checker's trace + checker.add_value(value.as_int() as u16); } let RangeCheckTrace { @@ -30,8 +30,7 @@ fn range_checks() { validate_row(&trace, &mut i, 0, 1); validate_row(&trace, &mut i, 1, 1); validate_row(&trace, &mut i, 2, 4); - validate_row(&trace, &mut i, 3, 2); - validate_row(&trace, &mut i, 3, 1); + validate_row(&trace, &mut i, 3, 3); validate_row(&trace, &mut i, 4, 2); validate_bridge_rows(&trace, &mut i, 4, 100); @@ -48,7 +47,6 @@ fn range_checks() { } #[test] -#[ignore = "update required"] fn range_checks_rand() { let mut checker = RangeChecker::new(); let values = rand_array::(); @@ -69,22 +67,13 @@ fn range_checks_rand() { // ================================================================================================ fn 
validate_row(trace: &[Vec], row_idx: &mut usize, value: u64, num_lookups: u64) { - let (s0, s1) = match num_lookups { - 0 => (ZERO, ZERO), - 1 => (ONE, ZERO), - 2 => (ZERO, ONE), - 4 => (ONE, ONE), - _ => panic!("invalid lookup value"), - }; - - assert_eq!(s0, trace[0][*row_idx]); - assert_eq!(s1, trace[1][*row_idx]); - assert_eq!(Felt::new(value), trace[2][*row_idx]); + assert_eq!(trace[0][*row_idx], Felt::from(num_lookups)); + assert_eq!(trace[1][*row_idx], Felt::from(value)); *row_idx += 1; } fn validate_trace(trace: &[Vec], lookups: &[Felt]) { - assert_eq!(3, trace.len()); + assert_eq!(2, trace.len()); // trace length must be a power of two let trace_len = get_trace_len(trace); @@ -95,8 +84,8 @@ fn validate_trace(trace: &[Vec], lookups: &[Felt]) { let mut lookups_16bit = BTreeMap::new(); // process the first row - assert_eq!(ZERO, trace[2][i]); - let count = get_lookup_count(trace, i); + assert_eq!(trace[1][i], ZERO); + let count = trace[0][i].as_int(); lookups_16bit.insert(0u16, count); i += 1; @@ -104,7 +93,7 @@ fn validate_trace(trace: &[Vec], lookups: &[Felt]) { let mut prev_value = 0u16; while i < trace_len { // make sure the value is a 16-bit value - let value = trace[2][i].as_int(); + let value = trace[1][i].as_int(); assert!(value <= 65535, "not a 16-bit value"); let value = value as u16; @@ -113,15 +102,18 @@ fn validate_trace(trace: &[Vec], lookups: &[Felt]) { assert!(valid_delta(delta)); // keep track of lookup count for each value - let count = get_lookup_count(trace, i); - lookups_16bit.entry(value).and_modify(|value| *value += count).or_insert(count); + let multiplicity = trace[0][i].as_int(); + lookups_16bit + .entry(value) + .and_modify(|count| *count += multiplicity) + .or_insert(multiplicity); i += 1; prev_value = value; } // validate the last row (must be 65535) - let last_value = trace[2][i - 1].as_int(); + let last_value = trace[1][i - 1].as_int(); assert_eq!(65535, last_value); // remove all the looked up values from the lookup table @@ 
-163,20 +155,6 @@ fn validate_bridge_rows( } } -fn get_lookup_count(trace: &[Vec], step: usize) -> usize { - if trace[0][step] == ZERO && trace[1][step] == ZERO { - 0 - } else if trace[0][step] == ONE && trace[1][step] == ZERO { - 1 - } else if trace[0][step] == ZERO && trace[1][step] == ONE { - 2 - } else if trace[0][step] == ONE && trace[1][step] == ONE { - 4 - } else { - panic!("not a valid count"); - } -} - /// Checks if the delta between two values is 0 or a power of 3 and at most 3^7 fn valid_delta(delta: u16) -> bool { delta == 0 || (59049 % delta == 0 && delta <= 2187) diff --git a/processor/src/trace/tests/range.rs b/processor/src/trace/tests/range.rs index 96e1c44e2a..653fdce0fc 100644 --- a/processor/src/trace/tests/range.rs +++ b/processor/src/trace/tests/range.rs @@ -3,14 +3,13 @@ use miden_air::trace::{ chiplets::hasher::HASH_CYCLE_LEN, range::B_RANGE_COL_IDX, AUX_TRACE_RAND_ELEMENTS, }; use rand_utils::rand_array; -use vm_core::Operation; +use vm_core::{ExtensionOf, Operation}; /// This test checks that range check lookups from stack operations are balanced by the range checks /// processed in the Range Checker. /// /// The `U32add` operation results in 4 16-bit range checks of 256, 0, 0, 0. #[test] -#[ignore = "update required"] fn b_range_trace_stack() { let stack = [1, 255]; let operations = vec![Operation::U32add]; @@ -30,37 +29,34 @@ fn b_range_trace_stack() { assert_eq!(Felt::ONE, b_range[1]); // The first range check lookup from the stack will happen when the add operation is executed, - // at cycle 1. (The trace begins by executing `span`). It must be divided out of `b_range`. - // The range-checked values are 0, 256, 0, 0. - let lookup_product = (alpha) * (Felt::new(256) + alpha) * alpha.square(); - let mut expected = lookup_product.inv(); + // at cycle 1. (The trace begins by executing `span`). It must be subtracted out of `b_range`. 
+ // The range-checked values are 0, 256, 0, 0, so the values to subtract are 3/(alpha - 0) and + // 1/(alpha - 256). + let lookups = alpha.inv().mul_base(Felt::new(3)) + (alpha - Felt::new(256)).inv(); + let mut expected = b_range[1] - lookups; assert_eq!(expected, b_range[2]); // --- Check the range checker's lookups. ----------------------------------------------------- - // 45 rows are needed for 0, 0, 243, 252, 255, 256, ... 38 additional bridge rows of - // powers of 3 ..., 65535. (0 is range-checked in 2 rows for a total of 3 lookups. 256 is - // range-checked in one row. 65535 is the max, and the rest are "bridge" values.) An extra row - // is added to pad the u16::MAX value. - let len_16bit = 45 + 1; + // 44 rows are needed for 0, 243, 252, 255, 256, ... 38 additional bridge rows of powers of + // 3 ..., 65535. (0 and 256 are range-checked. 65535 is the max, and the rest are "bridge" + // values.) An extra row is added to pad the u16::MAX value. + let len_16bit = 44 + 1; // The start of the values in the range checker table. let values_start = trace.length() - len_16bit - NUM_RAND_ROWS; // After the padded rows, the first value will be unchanged. assert_eq!(expected, b_range[values_start]); - // We include 2 lookups of 0, so the next value should be multiplied by alpha squared. - expected *= alpha.square(); + // We include 3 lookups of 0. + expected += alpha.inv().mul_base(Felt::new(3)); assert_eq!(expected, b_range[values_start + 1]); - // Then we include our third lookup of 0, so the next value should be multiplied by alpha. - expected *= alpha; - assert_eq!(expected, b_range[values_start + 2]); // Then we have 3 bridge rows between 0 and 255 where the value does not change + assert_eq!(expected, b_range[values_start + 2]); assert_eq!(expected, b_range[values_start + 3]); assert_eq!(expected, b_range[values_start + 4]); - assert_eq!(expected, b_range[values_start + 5]); // Then we include 1 lookup of 256, so it should be multiplied by alpha + 256. 
- expected *= alpha + Felt::new(256); - assert_eq!(expected, b_range[values_start + 6]); + expected += (alpha - Felt::new(256)).inv(); + assert_eq!(expected, b_range[values_start + 5]); // --- Check the last value of the b_range column is one. ------------------------------------------ @@ -75,7 +71,6 @@ fn b_range_trace_stack() { /// The `LoadW` memory operation results in 2 16-bit range checks of 0, 0. #[test] #[allow(clippy::needless_range_loop)] -#[ignore = "update required"] fn b_range_trace_mem() { let stack = [0, 1, 2, 3, 4, 0]; let operations = vec![ @@ -98,16 +93,15 @@ fn b_range_trace_mem() { // The memory section of the chiplets trace starts after the span hash. let memory_start = HASH_CYCLE_LEN; - // 41 rows are needed for 0, 0, 3, 4, ... 36 bridge additional bridge rows of powers of - // 3 ..., 65535. (0 is range-checked in 2 rows for a total of 3 lookups. Four is range - // checked in one row for a total of one lookup. 65535 is the max, and the rest are "bridge" + // 40 rows are needed for 0, 3, 4, ... 36 bridge additional bridge rows of powers of + // 3 ..., 65535. (0 and 4 are both range-checked. 65535 is the max, and the rest are "bridge" // values.) An extra row is added to pad the u16::MAX value. - let len_16bit = 41 + 1; + let len_16bit = 40 + 1; let values_start = trace.length() - len_16bit - NUM_RAND_ROWS; // The value should start at ONE and be unchanged until the memory processor section begins. let mut expected = ONE; - for row in 0..=memory_start { + for row in 0..memory_start { assert_eq!(expected, b_range[row]); } @@ -121,31 +115,33 @@ fn b_range_trace_mem() { let (d0_load, d1_load) = (Felt::new(4), ZERO); // Include the lookups from the `MStoreW` operation at the next row. - expected *= ((d0_store + alpha) * (d1_store + alpha)).inv(); + expected -= (alpha - d0_store).inv() + (alpha - d1_store).inv(); assert_eq!(expected, b_range[memory_start + 1]); // Include the lookup from the `MLoadW` operation at the next row. 
- expected *= ((d0_load + alpha) * (d1_load + alpha)).inv(); + expected -= (alpha - d0_load).inv() + (alpha - d1_load).inv(); assert_eq!(expected, b_range[memory_start + 2]); + // The value should be unchanged until the range checker's lookups are included. + for row in memory_start + 2..=values_start { + assert_eq!(expected, b_range[row]); + } + // --- Check the range checker's lookups. ----------------------------------------------------- - // We include 2 lookups of ZERO in the next row. - expected *= alpha.square(); + // We include 3 lookups of ZERO in the next row. + expected += alpha.inv().mul_base(Felt::new(3)); assert_eq!(expected, b_range[values_start + 1]); - // We include 1 more lookup of ZERO in the next row. - expected *= d0_store + alpha; - assert_eq!(expected, b_range[values_start + 2]); // then we have one bridge row between 0 and 4 where the value does not change. - assert_eq!(expected, b_range[values_start + 3]); + assert_eq!(expected, b_range[values_start + 2]); // We include 1 lookup of 4 in the next row. - expected *= d0_load + alpha; - assert_eq!(expected, b_range[values_start + 4]); + expected += (alpha - d0_load).inv(); + assert_eq!(expected, b_range[values_start + 3]); // --- The value should now be ONE for the rest of the trace. 
--------------------------------- assert_eq!(expected, ONE); - for i in (values_start + 4)..(b_range.len() - NUM_RAND_ROWS) { + for i in (values_start + 3)..(b_range.len() - NUM_RAND_ROWS) { assert_eq!(ONE, b_range[i]); } } From 382ff12f638dd94bdeafd1adaf37601c5a3eb77d Mon Sep 17 00:00:00 2001 From: grjte Date: Fri, 28 Jul 2023 20:01:51 -0400 Subject: [PATCH 08/18] feat(proc): optimize range checker aux trace generation --- processor/src/range/aux_trace.rs | 78 +++++++++++++++++++++++++------- processor/src/range/mod.rs | 23 ++++++---- 2 files changed, 76 insertions(+), 25 deletions(-) diff --git a/processor/src/range/aux_trace.rs b/processor/src/range/aux_trace.rs index 415689e7e0..a0d9c258bd 100644 --- a/processor/src/range/aux_trace.rs +++ b/processor/src/range/aux_trace.rs @@ -1,5 +1,6 @@ use super::{uninit_vector, BTreeMap, ColMatrix, Felt, FieldElement, Vec, NUM_RAND_ROWS}; use miden_air::trace::range::{M_COL_IDX, V_COL_IDX}; +use vm_core::StarkField; // AUXILIARY TRACE BUILDER // ================================================================================================ @@ -7,9 +8,11 @@ use miden_air::trace::range::{M_COL_IDX, V_COL_IDX}; /// Describes how to construct the execution trace of columns related to the range checker in the /// auxiliary segment of the trace. These are used in multiset checks. pub struct AuxTraceBuilder { + /// A list of the unique values for which range checks are performed. + lookup_values: Vec, /// Range check lookups performed by all user operations, grouped and sorted by the clock cycle /// at which they are requested. - cycle_range_checks: BTreeMap>, + cycle_lookups: BTreeMap>, // The index of the first row of Range Checker's trace when the padded rows end and values to // be range checked start. 
values_start: usize, @@ -18,9 +21,14 @@ pub struct AuxTraceBuilder { impl AuxTraceBuilder { // CONSTRUCTOR // -------------------------------------------------------------------------------------------- - pub fn new(cycle_range_checks: BTreeMap>, values_start: usize) -> Self { + pub fn new( + lookup_values: Vec, + cycle_lookups: BTreeMap>, + values_start: usize, + ) -> Self { Self { - cycle_range_checks, + lookup_values, + cycle_lookups, values_start, } } @@ -46,12 +54,10 @@ impl AuxTraceBuilder { fn build_aux_col_b_range>( &self, main_trace: &ColMatrix, - alphas: &[E], + rand_elements: &[E], ) -> Vec { - // TODO: replace this with an efficient solution - // // compute the inverses for range checks performed by operations. - // let (_, inv_row_values) = - // build_lookup_table_row_values(&self.cycle_range_check_values(), main_trace, alphas); + // run batch inversion on the lookup values + let divisors = get_divisors(&self.lookup_values, rand_elements[0]); // allocate memory for the running sum column and set the initial value to ONE let mut b_range = unsafe { uninit_vector(main_trace.num_rows()) }; @@ -62,7 +68,7 @@ impl AuxTraceBuilder { let mut b_range_idx = 0_usize; // the first half of the trace only includes values from the operations. 
- for (clk, range_checks) in self.cycle_range_checks.range(0..self.values_start as u32) { + for (clk, range_checks) in self.cycle_lookups.range(0..self.values_start as u32) { let clk = *clk as usize; // if we skipped some cycles since the last update was processed, values in the last @@ -78,8 +84,8 @@ impl AuxTraceBuilder { b_range[b_range_idx] = b_range[clk]; // include the operation lookups for lookup in range_checks.iter() { - let value = (alphas[0] - (*lookup).into()).inv(); - b_range[b_range_idx] -= value; + let value = divisors.get(lookup).expect("invalid lookup value {}"); + b_range[b_range_idx] -= *value; } } @@ -103,14 +109,19 @@ impl AuxTraceBuilder { { b_range_idx = row_idx + 1; - // add the value in the range checker: multiplicity / (alpha - lookup) - let lookup_val = (alphas[0] - (*lookup).into()).inv().mul_base(*multiplicity); - b_range[b_range_idx] = b_range[row_idx] + lookup_val; + if multiplicity.as_int() != 0 { + // add the value in the range checker: multiplicity / (alpha - lookup) + let value = divisors.get(&(lookup.as_int() as u16)).expect("invalid lookup value"); + b_range[b_range_idx] = b_range[row_idx] + value.mul_base(*multiplicity); + } else { + b_range[b_range_idx] = b_range[row_idx]; + } + // subtract the range checks requested by operations - if let Some(range_checks) = self.cycle_range_checks.get(&(row_idx as u32)) { + if let Some(range_checks) = self.cycle_lookups.get(&(row_idx as u32)) { for lookup in range_checks.iter() { - let value = (alphas[0] - (*lookup).into()).inv(); - b_range[b_range_idx] -= value; + let value = divisors.get(lookup).expect("invalid lookup value"); + b_range[b_range_idx] -= *value; } } } @@ -126,3 +137,36 @@ impl AuxTraceBuilder { b_range } } + +/// Runs batch inversion on all range check lookup values and returns a map which maps of each value +/// to the divisor used for including it in the LogUp lookup. In other words, the map contains +/// mappings of x to 1/(alpha - x). 
+fn get_divisors>( + lookup_values: &[u16], + alpha: E, +) -> BTreeMap { + // run batch inversion on the lookup values + let mut values = unsafe { uninit_vector(lookup_values.len()) }; + let mut inv_values = unsafe { uninit_vector(lookup_values.len()) }; + let mut log_values = BTreeMap::new(); + + let mut acc = E::ONE; + for (i, (value, inv_value)) in values.iter_mut().zip(inv_values.iter_mut()).enumerate() { + *inv_value = acc; + *value = alpha - E::from(lookup_values[i]); + acc *= *value; + } + + // invert the accumulated product + acc = acc.inv(); + + // multiply the accumulated product by the original values to compute the inverses, then + // build a map of inverses for the lookup values + for i in (0..lookup_values.len()).rev() { + inv_values[i] *= acc; + acc *= values[i]; + log_values.insert(lookup_values[i], inv_values[i]); + } + + log_values +} diff --git a/processor/src/range/mod.rs b/processor/src/range/mod.rs index d0794180bc..b22b2169a8 100644 --- a/processor/src/range/mod.rs +++ b/processor/src/range/mod.rs @@ -43,7 +43,7 @@ pub struct RangeChecker { /// Range check lookups performed by all user operations, grouped and sorted by clock cycle. /// Each cycle is mapped to a vector of the range checks requested at that cycle, which can come /// from the stack, memory, or both. - cycle_range_checks: BTreeMap>, + cycle_lookups: BTreeMap>, } impl RangeChecker { @@ -58,7 +58,7 @@ impl RangeChecker { lookups.insert(u16::MAX, 0); Self { lookups, - cycle_range_checks: BTreeMap::new(), + cycle_lookups: BTreeMap::new(), } } @@ -76,18 +76,21 @@ impl RangeChecker { // 4 lookups respectively. debug_assert!(values.len() == 2 || values.len() == 4); - let mut requests = Vec::new(); for value in values.iter() { // add the specified value to the trace of this range checker's lookups. 
self.add_value(*value); - requests.push(Felt::from(*value)); } // track the range check requests at each cycle - self.cycle_range_checks + // TODO: optimize this to use a struct instead of vectors, e.g.: + // struct MemoryLookupValues { + // num_lookups: u8, + // lookup_values: [u16; 6], + // } + self.cycle_lookups .entry(clk) - .and_modify(|entry| entry.append(&mut requests)) - .or_insert_with(|| requests); + .and_modify(|entry| entry.append(&mut values.to_vec())) + .or_insert_with(|| values.to_vec()); } // EXECUTION TRACE GENERATION (INTERNAL) @@ -145,7 +148,11 @@ impl RangeChecker { RangeCheckTrace { trace, - aux_builder: AuxTraceBuilder::new(self.cycle_range_checks, num_padding_rows), + aux_builder: AuxTraceBuilder::new( + self.lookups.keys().cloned().collect(), + self.cycle_lookups, + num_padding_rows, + ), } } From 2aaeb4042b2d87151466be492f935866fbfd7c35 Mon Sep 17 00:00:00 2001 From: grjte Date: Mon, 7 Aug 2023 09:40:28 -0400 Subject: [PATCH 09/18] docs: update changelog for range checker logup migration --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index da570a2677..0311a9e05a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ #### VM Internals - Simplified range checker and removed 1 main and 1 auxiliary trace column (#949). +- Migrated range checker lookups to use LogUp and reduced the number of trace columns to 2 main and + 1 auxiliary (#1027). - Added `get_mapped_values()` and `get_store_subset()` methods to the `AdviceProvider` trait (#987). - [BREAKING] Added options to specify maximum number of cycles and expected number of cycles for a program (#998). - Improved handling of invalid/incomplete parameters in `StackOutputs` constructors (#1010). 
From a2a33284529238281e3df433aacd5ea2ac616228 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Wed, 9 Aug 2023 01:09:02 +0200 Subject: [PATCH 10/18] feat(stdlib): Change prepare_message_schedule_and_consume from being exported to internal --- stdlib/asm/crypto/hashes/sha256.masm | 2 +- stdlib/docs/crypto/hashes/sha256.md | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/stdlib/asm/crypto/hashes/sha256.masm b/stdlib/asm/crypto/hashes/sha256.masm index 1c8fd90cab..197d966aee 100644 --- a/stdlib/asm/crypto/hashes/sha256.masm +++ b/stdlib/asm/crypto/hashes/sha256.masm @@ -233,7 +233,7 @@ end #! - msg0 through msg15 are the 64 -bytes input message (in terms of 16 SHA256 words) #! See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113 #! & https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) -export.prepare_message_schedule_and_consume.4 +proc.prepare_message_schedule_and_consume.4 loc_storew.0 loc_storew.2 dropw diff --git a/stdlib/docs/crypto/hashes/sha256.md b/stdlib/docs/crypto/hashes/sha256.md index e45dd6359b..edea4a9727 100644 --- a/stdlib/docs/crypto/hashes/sha256.md +++ b/stdlib/docs/crypto/hashes/sha256.md @@ -2,7 +2,6 @@ ## std::crypto::hashes::sha256 | Procedure | Description | | ----------- | ------------- | -| prepare_message_schedule_and_consume | Computes whole message schedule of 64 message words and consumes them into hash state.

Input: [state0, state1, state2, state3, state4, state5, state6, state7, msg0, msg1, msg2, msg3, msg4, msg5, msg6, msg7, msg8, msg9, msg10, msg11, msg12, msg13, msg14, msg15]

Output: [state0', state1', state2', state3', state4', state5', state6', state7']

Where:

- state0 through state7 are the hash state (in terms of 8 SHA256 words)

- msg0 through msg15 are the 64-byte input message (in terms of 16 SHA256 words)

See https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2.hpp#L89-L113

& https://github.com/itzmeanjan/merklize-sha/blob/8a2c006/include/sha2_256.hpp#L148-L187 ( loop body execution ) | | hash_2to1 | Given a 64-byte input, this routine computes its 32-byte SHA256 digest

Input: [m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,16) = 32-bit word

Note that each SHA256 word is 32 bits wide, so that is how the input is expected.

Since the input is 64 bytes, pack every 4 consecutive bytes into a single word,

maintaining big-endian byte order.

The SHA256 digest is represented in terms of eight 32-bit words (big-endian byte order). | | hash_1to1 | Given a 32-byte input, this routine computes its 32-byte SHA256 digest

Expected stack state:

Input: [m0, m1, m2, m3, m4, m5, m6, m7, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...]

Where: m[0,8) = 32-bit word

Note that each SHA256 word is 32 bits wide, so that is how the input is expected.

Since the input is 32 bytes, pack every 4 consecutive bytes into a single word,

maintaining big-endian byte order.

The SHA256 digest is represented in terms of eight 32-bit words (big-endian byte order). | | hash_memory | Given a memory address and a message length in bytes, computes the SHA256 digest of the message

- There must be space for writing the padding after the message in memory

- The padding space after the message must be all zeros before this procedure is called

Input: [addr, len, ...]

Output: [dig0, dig1, dig2, dig3, dig4, dig5, dig6, dig7, ...] | From 631e663f24c1ceb1e0f08ac5d358ebdadd859666 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Wed, 9 Aug 2023 01:58:41 +0200 Subject: [PATCH 11/18] feat(stdlib): Add test for sha256::hash_memory --- stdlib/tests/crypto/sha256.rs | 62 ++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/stdlib/tests/crypto/sha256.rs b/stdlib/tests/crypto/sha256.rs index 166d4ce83d..07c8028483 100644 --- a/stdlib/tests/crypto/sha256.rs +++ b/stdlib/tests/crypto/sha256.rs @@ -1,6 +1,66 @@ use crate::build_test; use sha2::{Digest, Sha256}; -use test_utils::{group_slice_elements, rand::rand_array, Felt, IntoBytes}; +use test_utils::{ + group_slice_elements, + rand::{rand_array, rand_value, rand_vector}, + Felt, IntoBytes, +}; + +#[test] +fn sha256_hash_memory() { + let source = " + use.std::crypto::hashes::sha256 + + begin + # mem.0 - input data address + push.10000 mem_store.0 + + # mem.1 - length in bytes + mem_store.1 + + # mem.2 - length in felts + mem_load.1 u32checked_add.3 u32checked_div.4 mem_store.2 + + # Load input data into memory address 10000, 10001, ... + mem_load.2 u32checked_neq.0 + while.true + mem_load.0 mem_storew dropw + mem_load.0 u32checked_add.1 mem_store.0 + mem_load.2 u32checked_sub.1 dup mem_store.2 u32checked_neq.0 + end + + # Compute hash of memory address 10000, 10001, ... 
+ mem_load.1 + push.10000 + exec.sha256::hash_memory + end"; + + let length = rand_value::() & 1023; // length: 0-1023 + let ibytes: Vec = rand_vector(length as usize); + let ipadding: Vec = vec![0; (4 - (length as usize % 4)) % 4]; + + let ifelts = [ + group_slice_elements::(&[ibytes.clone(), ipadding].concat()) + .iter() + .map(|&bytes| u32::from_be_bytes(bytes) as u64) + .rev() + .collect::>(), + vec![length as u64; 1], + ] + .concat(); + + let mut hasher = Sha256::new(); + hasher.update(ibytes); + + let obytes = hasher.finalize(); + let ofelts = group_slice_elements::(&obytes) + .iter() + .map(|&bytes| u32::from_be_bytes(bytes) as u64) + .collect::>(); + + let test = build_test!(source, &ifelts); + test.expect_stack(&ofelts); +} #[test] fn sha256_2_to_1_hash() { From 5b3b93b65cd8a422c30217596f32ea45db629c74 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Thu, 1 Jun 2023 01:59:11 -0700 Subject: [PATCH 12/18] feat: implementing updates and simple insertions into TSMT --- assembly/src/ast/nodes/advice.rs | 22 +- assembly/src/ast/parsers/adv_ops.rs | 4 + core/src/operations/decorators/advice.rs | 2 +- .../src/decorators/adv_stack_injectors.rs | 114 +++++ processor/src/decorators/mod.rs | 2 +- processor/src/decorators/tests.rs | 80 ++-- stdlib/asm/collections/smt.masm | 419 +++++++++++++++++- stdlib/docs/collections/smt.md | 3 +- stdlib/tests/collections/smt.rs | 151 ++++++- test-utils/src/crypto.rs | 12 +- 10 files changed, 745 insertions(+), 64 deletions(-) diff --git a/assembly/src/ast/nodes/advice.rs b/assembly/src/ast/nodes/advice.rs index 413fc31bd5..00d054da9e 100644 --- a/assembly/src/ast/nodes/advice.rs +++ b/assembly/src/ast/nodes/advice.rs @@ -18,6 +18,7 @@ pub enum AdviceInjectorNode { PushU64div, PushExt2intt, PushSmtGet, + PushSmtInsert, PushMapVal, PushMapValImm { offset: u8 }, PushMapValN, @@ -35,6 +36,7 @@ impl From<&AdviceInjectorNode> for AdviceInjector { PushU64div => Self::DivU64, PushExt2intt => Self::Ext2Intt, PushSmtGet => 
Self::SmtGet, + PushSmtInsert => Self::SmtInsert, PushMapVal => Self::MapValueToStack { include_len: false, key_offset: 0, @@ -68,6 +70,7 @@ impl fmt::Display for AdviceInjectorNode { PushU64div => write!(f, "push_u64div"), PushExt2intt => write!(f, "push_ext2intt"), PushSmtGet => write!(f, "push_smtget"), + PushSmtInsert => write!(f, "push_smtinsert"), PushMapVal => write!(f, "push_mapval"), PushMapValImm { offset } => write!(f, "push_mapval.{offset}"), PushMapValN => write!(f, "push_mapvaln"), @@ -86,14 +89,15 @@ impl fmt::Display for AdviceInjectorNode { const PUSH_U64DIV: u8 = 0; const PUSH_EXT2INTT: u8 = 1; const PUSH_SMTGET: u8 = 2; -const PUSH_MAPVAL: u8 = 3; -const PUSH_MAPVAL_IMM: u8 = 4; -const PUSH_MAPVALN: u8 = 5; -const PUSH_MAPVALN_IMM: u8 = 6; -const PUSH_MTNODE: u8 = 7; -const INSERT_MEM: u8 = 8; -const INSERT_HDWORD: u8 = 9; -const INSERT_HDWORD_IMM: u8 = 10; +const PUSH_SMTINSERT: u8 = 3; +const PUSH_MAPVAL: u8 = 4; +const PUSH_MAPVAL_IMM: u8 = 5; +const PUSH_MAPVALN: u8 = 6; +const PUSH_MAPVALN_IMM: u8 = 7; +const PUSH_MTNODE: u8 = 8; +const INSERT_MEM: u8 = 9; +const INSERT_HDWORD: u8 = 10; +const INSERT_HDWORD_IMM: u8 = 11; impl Serializable for AdviceInjectorNode { fn write_into(&self, target: &mut W) { @@ -102,6 +106,7 @@ impl Serializable for AdviceInjectorNode { PushU64div => target.write_u8(PUSH_U64DIV), PushExt2intt => target.write_u8(PUSH_EXT2INTT), PushSmtGet => target.write_u8(PUSH_SMTGET), + PushSmtInsert => target.write_u8(PUSH_SMTINSERT), PushMapVal => target.write_u8(PUSH_MAPVAL), PushMapValImm { offset } => { target.write_u8(PUSH_MAPVAL_IMM); @@ -129,6 +134,7 @@ impl Deserializable for AdviceInjectorNode { PUSH_U64DIV => Ok(AdviceInjectorNode::PushU64div), PUSH_EXT2INTT => Ok(AdviceInjectorNode::PushExt2intt), PUSH_SMTGET => Ok(AdviceInjectorNode::PushSmtGet), + PUSH_SMTINSERT => Ok(AdviceInjectorNode::PushSmtInsert), PUSH_MAPVAL => Ok(AdviceInjectorNode::PushMapVal), PUSH_MAPVAL_IMM => { let offset = source.read_u8()?; diff --git 
a/assembly/src/ast/parsers/adv_ops.rs b/assembly/src/ast/parsers/adv_ops.rs index 4c7042484d..1393796a05 100644 --- a/assembly/src/ast/parsers/adv_ops.rs +++ b/assembly/src/ast/parsers/adv_ops.rs @@ -33,6 +33,10 @@ pub fn parse_adv_inject(op: &Token) -> Result { 2 => AdvInject(PushSmtGet), _ => return Err(ParsingError::extra_param(op)), }, + "push_smtinsert" => match op.num_parts() { + 2 => AdvInject(PushSmtInsert), + _ => return Err(ParsingError::extra_param(op)), + }, "push_mapval" => match op.num_parts() { 2 => AdvInject(PushMapVal), 3 => { diff --git a/core/src/operations/decorators/advice.rs b/core/src/operations/decorators/advice.rs index c939cdbc80..d87ec83219 100644 --- a/core/src/operations/decorators/advice.rs +++ b/core/src/operations/decorators/advice.rs @@ -181,7 +181,7 @@ pub enum AdviceInjector { /// Where KEY is computed as hash(A || B, domain), where domain is provided via the immediate /// value. HdwordToMap { domain: Felt }, - + /// TODO: add docs SmtInsert, } diff --git a/processor/src/decorators/adv_stack_injectors.rs b/processor/src/decorators/adv_stack_injectors.rs index 1d9a2e6f2d..ec0f0ae8ca 100644 --- a/processor/src/decorators/adv_stack_injectors.rs +++ b/processor/src/decorators/adv_stack_injectors.rs @@ -338,6 +338,120 @@ where Ok(()) } + + /// Pushes values onto the advice stack which are required for successful insertion of a + /// key-value pair into a Sparse Merkle Tree data structure. + /// + /// The Sparse Merkle Tree is tiered, meaning it will have leaf depths in `{16, 32, 48, 64}`. + /// + /// Inputs: + /// Operand stack: [VALUE, KEY, ROOT, ...] + /// Advice stack: [...] + /// + /// Outputs: + /// Operand stack: [OLD_VALUE, NEW_ROOT, ...] 
+ /// Advice stack, depends on the type of insert: + /// - Simple insert at depth 16: [d0, d1, ONE (is_simple_insert), ZERO (is_update)] + /// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE] + /// - Update of an existing leaf: [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE] + /// + /// Where: + /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. + /// - d1 is a boolean flag set to `1` if the depth is `16` or `32`. + /// - P_NODE is an internal node located at the tier above the insert tier. + /// - VALUE is the value to be inserted. + /// - OLD_VALUE is the value previously associated with the specified KEY. + /// - ROOT and NEW_ROOT are the roots of the TSMT prior and post the insert respectively. + /// + /// # Errors + /// Will return an error if the provided Merkle root doesn't exist on the advice provider. + /// + /// # Panics + /// Will panic as unimplemented if the target depth is `64`. + pub(super) fn push_smtinsert_inputs(&mut self) -> Result<(), ExecutionError> { + // get the key and tree root from the stack + let key = [self.stack.get(7), self.stack.get(6), self.stack.get(5), self.stack.get(4)]; + let root = [self.stack.get(11), self.stack.get(10), self.stack.get(9), self.stack.get(8)]; + + // determine the depth of the first leaf or an empty tree node + let index = &key[3]; + let depth = self.advice_provider.get_leaf_depth(root, &SMT_MAX_TREE_DEPTH, index)?; + debug_assert!(depth < 65); + + // map the depth value to its tier; this rounds up depth to 16, 32, 48, or 64 + let depth = SMT_NORMALIZED_DEPTHS[depth as usize]; + if depth == 64 { + unimplemented!("handling of depth=64 tier hasn't been implemented yet"); + } + + // get the value of the node a this index/depth + let index = index.as_int() >> (64 - depth); + let index = Felt::new(index); + let node = self.advice_provider.get_tree_node(root, &Felt::new(depth as u64), &index)?; + + // figure out what kind of insert we are doing; 
possible options are: + // - if the node is a root of an empty subtree, this is a simple insert. + // - if the node is a leaf, this could be either an update (for the same key), or a + // complex insert (i.e., the existing leaf needs to be moved to a lower tier). + let empty = EmptySubtreeRoots::empty_hashes(64)[depth as usize]; + let (is_update, is_simple_insert) = if node == Word::from(empty) { + // handle simple insert case + if depth == 32 || depth == 48 { + // for depth 32 and 48, we need to provide the internal node located on the tier + // above the insert tier + let p_index = Felt::from(index.as_int() >> 16); + let p_depth = Felt::from(depth - 16); + let p_node = self.advice_provider.get_tree_node(root, &p_depth, &p_index)?; + for &element in p_node.iter().rev() { + self.advice_provider.push_stack(AdviceSource::Value(element))?; + } + } + + // return is_update = ZERO, is_simple_insert = ONE + (ZERO, ONE) + } else { + // if the node is a leaf node, push the elements mapped to this node onto the advice + // stack; the elements should be [KEY, VALUE], with key located at the top of the + // advice stack. + self.advice_provider.push_stack(AdviceSource::Map { + key: node, + include_len: false, + })?; + + // remove the KEY from the advice stack, leaving only the VALUE on the stack + let leaf_key = self.advice_provider.pop_stack_word()?; + + // if the key for the value to be inserted is the same as the leaf's key, we are + // dealing with a simple update. otherwise, we are dealing with a complex insert + // (i.e., the leaf needs to be moved to a lower tier). 
+ if leaf_key == key { + // return is_update = ONE, is_simple_insert = ZERO + (ONE, ZERO) + } else { + // return is_update = ZERO, is_simple_insert = ZERO + (ZERO, ZERO) + } + }; + + // set the flags used to determine which tier the insert is happening at + let is_16_or_32 = if depth == 16 || depth == 32 { ONE } else { ZERO }; + let is_16_or_48 = if depth == 16 || depth == 48 { ONE } else { ZERO }; + + self.advice_provider.push_stack(AdviceSource::Value(is_update))?; + if is_update == ONE { + // for update we don't need to specify whether we are dealing with an insert; but we + // insert an extra ZERO at the end so that we can read 4 values from the advice stack + // regardless of which branch is taken. + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + } else { + self.advice_provider.push_stack(AdviceSource::Value(is_simple_insert))?; + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; + } + Ok(()) + } } // HELPER FUNCTIONS diff --git a/processor/src/decorators/mod.rs b/processor/src/decorators/mod.rs index 25c7a28580..811f53b712 100644 --- a/processor/src/decorators/mod.rs +++ b/processor/src/decorators/mod.rs @@ -45,7 +45,7 @@ where AdviceInjector::Ext2Inv => self.push_ext2_inv_result(), AdviceInjector::Ext2Intt => self.push_ext2_intt_result(), AdviceInjector::SmtGet => self.push_smtget_inputs(), - AdviceInjector::SmtInsert => todo!(), + AdviceInjector::SmtInsert => self.push_smtinsert_inputs(), AdviceInjector::MemToMap => self.insert_mem_values_into_adv_map(), AdviceInjector::HdwordToMap { domain } => self.insert_hdword_into_adv_map(*domain), } diff --git a/processor/src/decorators/tests.rs b/processor/src/decorators/tests.rs index 4b9848f1b1..6cccdd8eb5 100644 --- a/processor/src/decorators/tests.rs +++
b/processor/src/decorators/tests.rs @@ -3,7 +3,7 @@ use super::{ Process, }; use crate::{MemAdviceProvider, StackInputs, Word}; -use test_utils::{crypto::get_smt_remaining_key, rand::seeded_word}; +use test_utils::rand::seeded_word; use vm_core::{ crypto::{ hash::{Rpo256, RpoDigest}, @@ -74,13 +74,10 @@ fn push_smtget() { // check leaves on empty trees for depth in [16, 32, 48] { - // compute the remaining key - let remaining = get_smt_remaining_key(key, depth); - // compute node value let depth_element = Felt::from(depth); let store = MerkleStore::new(); - let node = Rpo256::merge_in_domain(&[remaining.into(), value.into()], depth_element); + let node = Rpo256::merge_in_domain(&[key.into(), value.into()], depth_element); // expect absent value with constant depth 16 let expected = [ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ONE, ONE]; @@ -89,9 +86,6 @@ fn push_smtget() { // check leaves inserted on all tiers for depth in [16, 32, 48] { - // compute the remaining key - let remaining = get_smt_remaining_key(key, depth); - // set depth flags let is_16_or_32 = (depth == 16 || depth == 32).then_some(ONE).unwrap_or(ZERO); let is_16_or_48 = (depth == 16 || depth == 48).then_some(ONE).unwrap_or(ZERO); @@ -100,7 +94,7 @@ fn push_smtget() { let index = key[3].as_int() >> 64 - depth; let index = NodeIndex::new(depth, index).unwrap(); let depth_element = Felt::from(depth); - let node = Rpo256::merge_in_domain(&[remaining.into(), value.into()], depth_element); + let node = Rpo256::merge_in_domain(&[key.into(), value.into()], depth_element); // set tier node value and expect the value from the injector let mut store = MerkleStore::new(); @@ -111,10 +105,10 @@ fn push_smtget() { value[2], value[1], value[0], - remaining[3], - remaining[2], - remaining[1], - remaining[0], + key[3], + key[2], + key[1], + key[0], is_16_or_32, is_16_or_48, ]; @@ -158,23 +152,41 @@ fn inject_smtinsert() { let raw_a = 
0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; let key_a = build_key(raw_a); - let val_a = [ONE, ZERO, ZERO, ZERO]; - - // insertion should happen at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE; - // since we are replacing a node which is an empty subtree, the is_empty flag should also be ONE - let expected_stack = [ONE, ONE, ONE]; - let process = prepare_smt_insert(key_a, val_a, &smt, expected_stack.len()); + let val_a = [Felt::new(3), Felt::new(5), Felt::new(7), Felt::new(9)]; + + // this is a simple insertion at depth 16, and thus the flags should look as follows: + let is_update = ZERO; + let is_simple_insert = ONE; + let is_16_or_32 = ONE; + let is_16_or_48 = ONE; + let expected_stack = [is_update, is_simple_insert, is_16_or_32, is_16_or_48]; + let process = prepare_smt_insert(key_a, val_a, &smt, expected_stack.len(), Vec::new()); assert_eq!(build_expected(&expected_stack), process.stack.trace_state()); // --- update same key with different value ------------------------------- + // insert val_a into the tree so that val_b overwrites it + smt.insert(key_a.into(), val_a); let val_b = [ONE, ONE, ZERO, ZERO]; - smt.insert(key_a.into(), val_b); - // we are updating a node at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE; - // since we are updating an existing leaf, the is_empty flag should be set to ZERO - let expected_stack = [ZERO, ONE, ONE]; - let process = prepare_smt_insert(key_a, val_b, &smt, expected_stack.len()); + // this is a simple update, and thus the flags should look as follows: + let is_update = ONE; + let is_16_or_32 = ONE; + let is_16_or_48 = ONE; + + // also, the old value should be present in the advice stack: + let expected_stack = [ + val_a[3], + val_a[2], + val_a[1], + val_a[0], + is_update, + is_16_or_32, + is_16_or_48, + ZERO, + ]; + let adv_map = vec![build_adv_map_entry(key_a, val_a, 16)]; + let process = prepare_smt_insert(key_a, val_b, &smt, 
expected_stack.len(), adv_map); assert_eq!(build_expected(&expected_stack), process.stack.trace_state()); } @@ -183,12 +195,13 @@ fn prepare_smt_insert( value: Word, smt: &TieredSmt, adv_stack_depth: usize, + adv_map: Vec<([u8; 32], Vec)>, ) -> Process { let root: Word = smt.root().into(); let store = MerkleStore::from(smt); let stack_inputs = build_stack_inputs(value, key, root); - let advice_inputs = AdviceInputs::default().with_merkle_store(store); + let advice_inputs = AdviceInputs::default().with_merkle_store(store).with_map(adv_map); let mut process = build_process(stack_inputs, advice_inputs); process.execute_op(Operation::Noop).unwrap(); @@ -217,7 +230,7 @@ fn build_expected(values: &[Felt]) -> [Felt; 16] { } fn assert_case_smtget( - depth: u8, + _depth: u8, key: Word, value: Word, node: RpoDigest, @@ -226,9 +239,8 @@ fn assert_case_smtget( expected_stack: &[Felt], ) { // build the process - let stack_inputs = build_stack_inputs(key, root, Word::default()); - let remaining = get_smt_remaining_key(key, depth); - let mapped = remaining.into_iter().chain(value.into_iter()).collect(); + let stack_inputs = build_stack_inputs(key, root.into(), Word::default()); + let mapped = key.into_iter().chain(value.into_iter()).collect(); let advice_inputs = AdviceInputs::default() .with_merkle_store(store) .with_map([(node.into_bytes(), mapped)]); @@ -251,7 +263,7 @@ fn build_process( adv_inputs: AdviceInputs, ) -> Process { let advice_provider = MemAdviceProvider::from(adv_inputs); - Process::new(Kernel::default(), stack_inputs, advice_provider) + Process::new(Kernel::default(), stack_inputs, advice_provider, ExecutionOptions::default()) } fn build_stack_inputs(w0: Word, w1: Word, w2: Word) -> StackInputs { @@ -288,3 +300,11 @@ fn move_adv_to_stack(process: &mut Process, adv_stack_depth: process.execute_op(Operation::AdvPop).unwrap(); } } + +fn build_adv_map_entry(key: Word, val: Word, depth: u8) -> ([u8; 32], Vec) { + let node = Rpo256::merge_in_domain(&[key.into(), 
val.into()], Felt::from(depth)); + let mut elements = Vec::new(); + elements.extend_from_slice(&key); + elements.extend_from_slice(&val); + (node.into(), elements) +} diff --git a/stdlib/asm/collections/smt.masm b/stdlib/asm/collections/smt.masm index b7f93a4112..96faca1fbe 100644 --- a/stdlib/asm/collections/smt.masm +++ b/stdlib/asm/collections/smt.masm @@ -16,6 +16,9 @@ const.EMPTY_48_1=5634734408638476525 const.EMPTY_48_2=9233115969432897632 const.EMPTY_48_3=1437907447409278328 +# HELPER METHODS +# ================================================================================================= + #! Extracts 16 most significant bits from the passed-in value. #! #! Input: [v, ...] @@ -28,6 +31,16 @@ proc.get_top_16_bits u32unchecked_shr.16 end +#! Extracts 32 most significant bits from the passed-in value. +#! +#! Input: [v, ...] +#! Output: [v >> 32, ...] +#! +#! Cycles: 3 +proc.get_top_32_bits + u32split swap drop +end + #! Extracts 48 most significant bits from the passed-in value. #! #! Input: [v, ...] @@ -43,13 +56,16 @@ proc.get_top_48_bits add end +# GET +# ================================================================================================= + #! Get the leaf value for depth 16. #! #! Input: [K, R, ...] #! Output: [V, R, ...] #! #! Cycles: 85 -proc.get16.2 +proc.get_16.2 # compute index of the node by extracting top 16 bits from the key (8 cycles) dup exec.get_top_16_bits movdn.4 # => [K, i, R, ...] @@ -123,7 +139,7 @@ end #! Output: [V, R, ...] #! #! Cycles: 81 -proc.get32.2 +proc.get_32.2 # compute index of the node by extracting top 16 bits from the key (4 cycles) dup u32split movdn.5 drop # => [K, i, R, ...] @@ -197,7 +213,7 @@ end #! Output: [V, R, ...] #! #! Cycles: 88 -proc.get48.2 +proc.get_48.2 # compute index of the node by extracting top 48 bits from the key (11 cycles) dup exec.get_top_48_bits movdn.4 # => [K, i, R, ...] @@ -275,6 +291,7 @@ end #! Depth 16: 91 cycles #! Depth 32: 87 cycles #! Depth 48: 94 cycles +#! 
Depth 64: unimplemented export.get # invoke adv and fetch target depth flags adv.push_smtget adv_push.2 @@ -284,15 +301,15 @@ export.get if.true if.true # depth 16 - exec.get16 + exec.get_16 else # depth 32 - exec.get32 + exec.get_32 end else if.true # depth 48 - exec.get48 + exec.get_48 else # depth 64 # currently not implemented @@ -301,3 +318,393 @@ export.get end # => [V, R, ...] end + +# INSERT +# ================================================================================================= + +#! Updates a leaf node at depths 16, 32, or 48. +#! +#! Input: [d, idx, V, K, R, ...]; +#! Output: [V_old, R_new, ...] +#! +#! Where: +#! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. +#! - d, idx are the depth and index (at that depth) of the leaf node to be updated. +#! - K, V are the key-value pair for the leaf node where V is a new value for key K. +#! - V_old is the value previously stored under key K. +#! +#! This procedure succeeds only if: +#! - Node to be replaced at (d, idx) is a leaf node for the same key K. +#! +#! Cycles: 101 +proc.update_16_32_48.2 + # save [idx, d, 0, 0] in loc[0] (5 cycles) + push.0.0 loc_storew.0 + # => [0, 0, d, idx, V, K, R, ...] + + # prepare the stack for computing N = hash([K, V], domain=d), and also save K into loc[1] + # (10 cycles) + movdn.3 movup.2 drop push.0 swapw.2 loc_storew.1 swapw + # => [V, K, 0, 0, d, 0, R, ...] + + # compute the hash of the node N = hash([K, V], domain=d) - (1 cycle) + hperm + # => [X, N, X, R, ...] + + # prepare the stack for the mtree_set operation (8 cycles) + swapw.3 swapw swapw.2 loc_loadw.0 drop drop + # => [d, idx, R, N, X, ...] + + # insert the new leaf node into the tree at the specified index/depth; this also leaves the + # previous value of the node on the stack (29 cycle) + mtree_set + # => [N_old, R_new, X, ...] 
+ + # verify that N_old is a leaf node for the same key K + + # prepare the stack for computing E = hash([K, V_old], domain=d); value of V_old is read + # from the advice provider and is saved into loc[0] (21 cycles) + swapw.2 loc_loadw.0 movdn.3 push.0 movup.3 push.0.0.0 loc_loadw.1 adv_push.4 loc_storew.0 + # => [V_old, K, 0, 0, d, 0, R_new, N_old, ...] + + # compute E = hash([K, V_old], domain=d) + # (10 cycle) + hperm dropw swapw dropw + # => [E, R_new, N_old, ...] + + # make sure E and N_old are the same (14 cycles) + swapw swapw.2 + repeat.4 + dup.4 assert_eq + end + # => [E, R_new, ...] + + # load the old value (which we saved previously) onto the stack (3 cycles) + loc_loadw.0 + # => [V_old, R_new, ...] +end + +#! Inserts a new leaf node at depth 16. +#! +#! Input: [V, K, R, ...]; +#! Output:[0, 0, 0, 0, R_new, ...] +#! +#! Where: +#! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. +#! - K and V is the key-value pair for the leaf node to be inserted. +#! +#! This procedure succeeds only if: +#! - Node to be replaced at depth 16 is a root of an empty subtree. +#! +#! Cycles: 73 +proc.insert_16 + # extract 16-bit index from the key (8 cycles) + swapw dup exec.get_top_16_bits + # => [idx, K, V, R, ...] + + # prepare the stack for computing leaf node value (6 cycles) + movdn.8 push.0.16.0.0 swapw.2 + # => [V, K, 0, 0, 16, 0, idx, R, ...] + + # compute leaf node value as N = hash([K, V], domain=16) (10 cycles) + hperm dropw swapw dropw + # => [N, idx, R, ...] + + # prepare the stack for mtree_set operation (4 cycles) + swapw movup.8 movdn.4 push.16 + # => [16, idx, R, N, ...] + + # insert the node into the tree at depth 16; this also leaves the old value of the node on the + # stack (29 cycle) + mtree_set + # => [N_old, R_new, ...] 
+ + # verify that the old value of the node was a root of an empty subtree for depth 16 (12 cycles) + push.EMPTY_16_3 assert_eq + push.EMPTY_16_2 assert_eq + push.EMPTY_16_1 assert_eq + push.EMPTY_16_0 assert_eq + + # put the return value onto the stack and return (4 cycles) + padw + # => [0, 0, 0, 0, R_new, ...] +end + +#! Inserts a new leaf node at depth 32. +#! +#! Input: [V, K, R, ...]; +#! Output:[0, 0, 0, 0, R_new, ...] +#! +#! Where: +#! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. +#! - K, V is the key-value pair for the leaf node to be inserted into the TSMT. +#! +#! This procedure consists of two high-level steps: +#! - First, insert N = hash([K, V], domain=32) into a subtree with root P, where P is the +#! internal node at depth 16 on the path to the new leaf node. This outputs the new root +#! of the subtree P_new. +#! - Then, insert P_new into the TSMT with root R. +#! +#! We do this to minimize the number of hashes consumed by the procedure for Merkle path +#! verification. Specifically, Merkle path verification will require exactly 64 hashes. +#! +#! This procedure succeeds only if: +#! - Node at depth 16 is an internal node. +#! - Node at depth 32 is a root of an empty subtree. +#! +#! Cycles: 154 +proc.insert_32.2 + # load the value of P from the advice provider (5 cycles) + adv_push.4 swapw.2 + # => [K, V, P, R, ...] + + # save k3 into loc[0][0] (4 cycles) + dup loc_store.0 + # => [K, V, P, R, ...] + + # prepare the stack for computing N = hash([K, V], domain=32) - (5 cycles) + push.0.32.0.0 swapw.2 + # => [V, K, 0, 0, 32, 0, P, R, ...] + + # compute N = hash([K, V], domain=32) (1 cycle) + hperm + # => [X, N, X, P, R, ...] + + # save P into loc[1] to be used later (5 cycles) + swapw.3 loc_storew.1 + # => [P, N, X, X, R, ...] 
+ + # make sure P is not a root of an empty subtree at depth 16 (17 cycles) + dup push.EMPTY_16_3 eq + dup.2 push.EMPTY_16_2 eq + dup.4 push.EMPTY_16_1 eq + dup.6 push.EMPTY_16_0 eq + and and and assertz + # => [P, N, X, X, R, ...] + + # load k3 from memory, extract upper 32 bits from it and split them into two 16-bit values + # such that the top 16-bits are in idx_hi and the next 16 bits are in idx_lo (9 cycles) + loc_load.0 exec.get_top_32_bits u32unchecked_divmod.65536 + # => [idx_lo, idx_hi, P, N, X, X, R, ...] + + # save idx_hi into loc[0][0] to be used later (5 cycles) + swap loc_store.0 + # => [idx_lo, P, N, X, X, R, ...] + + # replace node at idx_lo in P with N, the old value of the node is left on the stack; this also + # proves that P is not a leaf node because a leaf node cannot have children at depth 16 (30 cycles) + push.16 mtree_set + # => [N_old, P_new, X, X, R, ...] + + # make sure that N_old is a root of an empty subtree at depth 32 (12 cycles) + push.EMPTY_32_3 assert_eq + push.EMPTY_32_2 assert_eq + push.EMPTY_32_1 assert_eq + push.EMPTY_32_0 assert_eq + # => [P_new, X, X, R, ...] + + # prepare the stack for mtree_set operation against R; here we load idx_hi from loc[0][0] + # (11 cycles) + swapw.2 dropw swapw.2 loc_load.0 push.16 + # => [16, idx_hi, R, P_new, X, ...] + + # insert P_new into tree with root R at depth 16 and idx_hi index (29 cycles) + mtree_set + # => [P_old, R_new, X, ...] + + # load previously saved P to compare it with P_old (6 cycles) + swapw swapw.2 loc_loadw.1 + # => [P, P_old, R_new, ...] + + # make sure P and P_old are the same (11 cycles) + assert_eqw + # => [R_new, ...] + + # put the return value onto the stack and return (4 cycles) + padw + # => [0, 0, 0, 0, R_new, ...] +end + +#! Inserts a new leaf node at depth 48. +#! +#! Input: [V, K, R, ...]; +#! Output:[0, 0, 0, 0, R_new, ...] +#! +#! This procedure is nearly identical to the insert_32 procedure above, adjusted for the use of +#!
constants and idx_hi/idx_lo computation. It may be possible to combine the two at the expense +#! of extra 10 - 20 cycles. +proc.insert_48.2 + # load the value of P from the advice provider (5 cycles) + adv_push.4 swapw.2 + # => [K, V, P, R, ...] + + # save k3 into loc[0][0] (4 cycles) + dup loc_store.0 + # => [K, V, P, R, ...] + + # prepare the stack for computing N = hash([K, V], domain=48) - (5 cycles) + push.0.48.0.0 swapw.2 + # => [V, K, 0, 0, 48, 0, P, R, ...] + + # compute N = hash([K, V], domain=48) (1 cycle) + hperm + # => [X, N, X, P, R, ...] + + # save P into loc[1] to be used later (5 cycles) + swapw.3 loc_storew.1 + # => [P, N, X, X, R, ...] + + # make sure P is not a root of an empty subtree at depth 32 (17 cycles) + dup push.EMPTY_32_3 eq + dup.2 push.EMPTY_32_2 eq + dup.4 push.EMPTY_32_1 eq + dup.6 push.EMPTY_32_0 eq + and and and assertz + # => [P, N, X, X, R, ...] + + # load k3 from memory, extract upper 48 bits from it and split them into two values such that + # the top 32-bits are in idx_hi and the next 16 bits are in idx_lo (9 cycles) + loc_load.0 u32split swap u32unchecked_divmod.65536 drop + # => [idx_lo, idx_hi, P, N, X, X, R, ...] + + # save idx_hi into loc[0][0] to be used later (5 cycles) + swap loc_store.0 + # => [idx_lo, P, N, X, X, R, ...] + + # replace node at idx_lo in P with N, the old value of the node is left on the stack; this also + # proves that P is not a leaf node because a leaf node cannot have children at depth 16 (30 cycles) + push.16 mtree_set + # => [N_old, P_new, X, X, R, ...] + + # make sure that N_old is a root of an empty subtree at depth 48 (12 cycles) + push.EMPTY_48_3 assert_eq + push.EMPTY_48_2 assert_eq + push.EMPTY_48_1 assert_eq + push.EMPTY_48_0 assert_eq + # => [P_new, X, X, R, ...] + + # prepare the stack for mtree_set operation against R; here we load idx_hi from loc[0][0] + # (11 cycles) + swapw.2 dropw swapw.2 loc_load.0 push.32 + # => [32, idx_hi, R, P_new, X, ...]
+ + # insert P_new into tree with root R at depth 32 and idx_hi index (29 cycles) + mtree_set + # => [P_old, R_new, X, ...] + + # load previously saved P with P_old to make sure they are the same (6 cycles) + swapw swapw.2 loc_loadw.1 + # => [P, P_old, R_new, ...] + + # make sure P and P_old are the same (11 cycles) + assert_eqw + # => [R_new, ...] + + # put the return value onto the stack and return (4 cycles) + padw + # => [0, 0, 0, 0, R_new, ...] +end + +#! Inserts the specified value into a Sparse Merkle Tree with the specified root under the +#! specified key. +#! +#! The value previously stored in the SMT under this key is left on the stack together with +#! the updated tree root. +#! +#! This assumes that the value is not [ZERO; 4]. If it is, the procedure fails. +#! +#! Input: [V, K, R, ...]; +#! Output:[V_old, R', ...] +#! +#! Cycles: +#! - Update existing leaf: +#! - Depth 16: 129 +#! - Depth 32: 126 +#! - Depth 48: 131 +#! - Insert new leaf: +#! - Depth 16: 100 +#! - Depth 32: 181 +#! - Depth 48: 181 +#! - Replace a leaf with a subtree: +#! - Depth 32: TODO +#! - Depth 48: TODO +export.insert + # make sure the value is not [ZERO; 4] (17 cycles) + repeat.4 + dup.3 eq.0 + end + and and and assertz + # => [V, K, R, ...] + + # arrange the data needed for the insert procedure on the advice stack and move the + # first 4 flags onto the operand stack; meaning of the flags f0, f1, and f2 depends + # on what type of insert is being executed (4 cycles) + adv.push_smtinsert adv_push.4 + # => [is_update, f0, f1, f2, V, K, R, ...] + + # call the inner procedure depending on the type of insert and depth + if.true # --- update leaf ------------------------------------------------- + # => [is_16_or_32, is_16_or_48, ZERO, V, K, R, ...] + if.true + if.true # --- update a leaf node at depth 16 --- + drop + # => [V, K, R, ...] + + # (cycles 8) + dup.4 exec.get_top_16_bits + push.16 + # => [16, idx, V, K, R, ...] 
+ + exec.update_16_32_48 + else # --- update a leaf node at depth 32 --- + drop + # => [V, K, R, ...] + + #(5 cycles) + dup.4 exec.get_top_32_bits + push.32 + # => [32, idx, V, K, R, ...] + + exec.update_16_32_48 + end + else + if.true # --- update a leaf node at depth 48 --- + drop + # => [V, K, R, ...] + + # (10 cycles) + dup.4 exec.get_top_48_bits + push.48 + # => [48, idx, V, K, R, ...] + + exec.update_16_32_48 + else + # depth 64 - currently not implemented + push.0 assert + end + end + else + # => [is_simple_insert, is_16_or_32, is_16_or_48, V, K, R, ...] + if.true # --- insert new leaf --------------------------------------------- + if.true + if.true + exec.insert_16 + else + exec.insert_32 + end + else + if.true + exec.insert_48 + else + # depth 64 - currently not implemented + push.0 assert + end + end + else # --- replace leaf with subtree ---------------------------------- + # TODO: implement replace leaf with subtree + push.0 assert + end + end + + # => [V, R, ...] +end diff --git a/stdlib/docs/collections/smt.md b/stdlib/docs/collections/smt.md index 85ee7a889d..edb31797fe 100644 --- a/stdlib/docs/collections/smt.md +++ b/stdlib/docs/collections/smt.md @@ -2,4 +2,5 @@ ## std::collections::smt | Procedure | Description | | ----------- | ------------- | -| get | Returns the value stored under the specified key in a Sparse Merkle Tree with the specified root.

If the value for a given key has not been set, the returned `V` will consist of all zeroes.

Input: [K, R, ...]

Output: [V, R, ...]

Depth 16: 91 cycles

Depth 32: 87 cycles

Depth 48: 94 cycles | +| get | Returns the value stored under the specified key in a Sparse Merkle Tree with the specified root.

If the value for a given key has not been set, the returned `V` will consist of all zeroes.

Input: [K, R, ...]

Output: [V, R, ...]

Depth 16: 91 cycles

Depth 32: 87 cycles

Depth 48: 94 cycles

Depth 64: unimplemented | +| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R', ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 32: TODO

- Depth 48: TODO | diff --git a/stdlib/tests/collections/smt.rs b/stdlib/tests/collections/smt.rs index 9618368cd9..10f6c4ad60 100644 --- a/stdlib/tests/collections/smt.rs +++ b/stdlib/tests/collections/smt.rs @@ -13,7 +13,7 @@ const EMPTY_VALUE: Word = TieredSmt::EMPTY_VALUE; // ================================================================================================ #[test] -fn smtget_depth_16() { +fn tsmt_get_16() { let mut smt = TieredSmt::default(); // create a key @@ -40,7 +40,7 @@ fn smtget_depth_16() { } #[test] -fn smtget_depth_32() { +fn tsmt_get_32() { let mut smt = TieredSmt::default(); // populate the tree with two key-value pairs sharing the same 16-bit prefix for the keys @@ -75,7 +75,7 @@ fn smtget_depth_32() { } #[test] -fn smtget_depth_48() { +fn tsmt_get_48() { let mut smt = TieredSmt::default(); // populate the tree with two key-value pairs sharing the same 32-bit prefix for the keys @@ -109,6 +109,140 @@ fn smtget_depth_48() { assert_smt_get_opens_correctly(&smt, key_e, EMPTY_VALUE); } +// INSERTS +// ================================================================================================ + +#[test] +fn tsmt_insert_16() { + let mut smt = TieredSmt::default(); + + let raw_a = 0b00000000_00000000_11111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a1 = [ONE, ZERO, ZERO, ZERO]; + let val_a2 = [ONE, ONE, ZERO, ZERO]; + + // insert a value under key_a into an empty tree + let init_smt = smt.clone(); + smt.insert(key_a.into(), val_a1); + assert_insert(&init_smt, key_a, EMPTY_VALUE, val_a1, smt.root().into()); + + // update a value under key_a + let init_smt = smt.clone(); + smt.insert(key_a.into(), val_a2); + assert_insert(&init_smt, key_a, val_a1, val_a2, smt.root().into()); +} + +#[test] +fn tsmt_insert_32() { + let mut smt = TieredSmt::default(); + + // insert a value under key_a into an empty tree + let raw_a = 
0b00000000_00000000_11111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_a.into(), val_a); + + // insert a value under key_b which has the same 16-bit prefix as A + let raw_b = 0b00000000_00000000_01111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [ONE, ONE, ZERO, ZERO]; + + // TODO: test this insertion once complex inserts are working + smt.insert(key_b.into(), val_b); + + // update a value under key_a + let init_smt = smt.clone(); + let val_a2 = [ONE, ZERO, ZERO, ONE]; + smt.insert(key_a.into(), val_a2); + assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into()); + + // insert a value under key_c which has the same 16-bit prefix as A and B + let raw_c = 0b00000000_00000000_00111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let val_c = [ONE, ONE, ONE, ZERO]; + + let init_smt = smt.clone(); + smt.insert(key_c.into(), val_c); + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); +} + +#[test] +fn tsmt_insert_48() { + let mut smt = TieredSmt::default(); + + // insert a value under key_a into an empty tree + let raw_a = 0b00000000_00000000_11111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_a.into(), val_a); + + // insert a value under key_b which has the same 32-bit prefix as A + let raw_b = 0b00000000_00000000_11111111_11111111_01111111_11111111_11111111_11111111_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [ONE, ONE, ZERO, ZERO]; + + // TODO: test this insertion once complex inserts are working + smt.insert(key_b.into(), val_b); + + // update a value under key_a + let init_smt = 
smt.clone(); + let val_a2 = [ONE, ZERO, ZERO, ONE]; + smt.insert(key_a.into(), val_a2); + assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into()); + + // insert a value under key_c which has the same 32-bit prefix as A and B + let raw_c = 0b00000000_00000000_11111111_11111111_00111111_11111111_11111111_11111111_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let val_c = [ONE, ONE, ONE, ZERO]; + + let init_smt = smt.clone(); + smt.insert(key_c.into(), val_c); + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); +} + +fn assert_insert( + init_smt: &TieredSmt, + key: RpoDigest, + old_value: Word, + new_value: Word, + new_root: RpoDigest, +) { + let old_root = init_smt.root(); + let source = r#" + use.std::collections::smt + + begin + exec.smt::insert + end + "#; + let initial_stack = [ + old_root[0].as_int(), + old_root[1].as_int(), + old_root[2].as_int(), + old_root[3].as_int(), + key[0].as_int(), + key[1].as_int(), + key[2].as_int(), + key[3].as_int(), + new_value[0].as_int(), + new_value[1].as_int(), + new_value[2].as_int(), + new_value[3].as_int(), + ]; + let expected_output = [ + old_value[3].as_int(), + old_value[2].as_int(), + old_value[1].as_int(), + old_value[0].as_int(), + new_root[3].as_int(), + new_root[2].as_int(), + new_root[1].as_int(), + new_root[0].as_int(), + ]; + let (store, adv_map) = build_advice_inputs(init_smt); + build_test!(source, &initial_stack, &[], store, adv_map).expect_stack(&expected_output); +} + // TEST HELPERS // ================================================================================================ @@ -143,6 +277,13 @@ fn assert_smt_get_opens_correctly(smt: &TieredSmt, key: RpoDigest, value: Word) root[0].as_int(), ]; + let (store, advice_map) = build_advice_inputs(smt); + let advice_stack = []; + build_test!(source, &initial_stack, &advice_stack, store, advice_map.into_iter()) + .expect_stack(&expected_output); +} + +fn build_advice_inputs(smt: &TieredSmt) -> 
(MerkleStore, Vec<([u8; 32], Vec)>) { let store = MerkleStore::from(smt); let advice_map = smt .upper_leaves() @@ -153,7 +294,5 @@ fn assert_smt_get_opens_correctly(smt: &TieredSmt, key: RpoDigest, value: Word) }) .collect::>(); - let advice_stack = []; - build_test!(source, &initial_stack, &advice_stack, store, advice_map.into_iter()) - .expect_stack(&expected_output); + (store, advice_map) } diff --git a/test-utils/src/crypto.rs b/test-utils/src/crypto.rs index e06c516098..5a807f32b8 100644 --- a/test-utils/src/crypto.rs +++ b/test-utils/src/crypto.rs @@ -1,4 +1,4 @@ -use super::{Felt, FieldElement, StarkField, Vec, Word}; +use super::{Felt, FieldElement, Vec, Word}; // RE-EXPORTS // ================================================================================================ @@ -32,13 +32,3 @@ pub fn init_merkle_leaves(values: &[u64]) -> Vec { pub fn init_merkle_leaf(value: u64) -> Word { [Felt::new(value), Felt::ZERO, Felt::ZERO, Felt::ZERO] } - -/// Returns a remaining path key for a Sparse Merkle Tree -pub fn get_smt_remaining_key(mut key: Word, depth: u8) -> Word { - key[3] = Felt::new(match depth { - 16 | 32 | 48 => (key[3].as_int() << depth) >> depth, - 64 => 0, - _ => unreachable!(), - }); - key -} From 4f1dc0ecf5bdad8a0979685a9701ab3e82e84889 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Thu, 10 Aug 2023 01:09:18 -0700 Subject: [PATCH 13/18] feat: implement TSMT complex insertion for 15->32 case --- .../src/decorators/adv_stack_injectors.rs | 28 ++++ stdlib/asm/collections/smt.masm | 148 +++++++++++++++++- stdlib/docs/collections/smt.md | 2 +- stdlib/tests/collections/smt.rs | 3 +- 4 files changed, 174 insertions(+), 7 deletions(-) diff --git a/processor/src/decorators/adv_stack_injectors.rs b/processor/src/decorators/adv_stack_injectors.rs index ec0f0ae8ca..ee240ef740 100644 --- a/processor/src/decorators/adv_stack_injectors.rs +++ b/processor/src/decorators/adv_stack_injectors.rs @@ -354,6 +354,8 @@ where /// - Simple insert at depth 16: 
[d0, d1, ONE (is_simple_insert), ZERO (is_update)] /// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE] /// - Update of an existing leaf: [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE] + /// - Replace leaf node with subtree 16->32: [ONE, ONE, ZERO, ZERO, P_KEY, P_VALUE] + /// - Update of an existing leaf: [ONE, d0, d1, ONE, OLD_VALUE] /// /// Where: /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. @@ -361,6 +363,7 @@ where /// - P_NODE is an internal node located at the tier above the insert tier. /// - VALUE is the value to be inserted. /// - OLD_VALUE is the value previously associated with the specified KEY. + /// - P_KEY and P_VALUE are the key-value pair for a leaf which is to be replaced by a subtree. /// - ROOT and NEW_ROOT are the roots of the TSMT prior and post the insert respectively. /// /// # Errors @@ -428,6 +431,25 @@ where // return is_update = ONE, is_simple_insert = ZERO (ONE, ZERO) } else { + // TODO: improve code readability as more cases are handled + let common_prefix = get_common_prefix(&key, &leaf_key); + if depth == 16 { + if common_prefix < 32 { + // put the key back onto the advice stack + for &element in leaf_key.iter().rev() { + self.advice_provider.push_stack(AdviceSource::Value(element))?; + } + } else { + todo!("handle moving leaf from depth 16 to 48 or 64") + } + } else if depth == 32 { + todo!("handle moving leaf from depth 32 to 48 or 64") + } else if depth == 48 { + todo!("handle moving leaf from depth 48 to 64") + } else { + todo!("handle inserting key-value pair into existing leaf at depth 64") + } + // return is_update = ZERO, is_simple_insert = ZERO (ZERO, ZERO) } @@ -462,3 +484,9 @@ fn u64_to_u32_elements(value: u64) -> (Felt, Felt) { let lo = Felt::new((value as u32) as u64); (hi, lo) } + +fn get_common_prefix(key1: &Word, key2: &Word) -> u8 { + let k1 = key1[3].as_int(); + let k2 = key2[3].as_int(); + (k1 ^ k2).leading_zeros() as u8 +} diff --git 
a/stdlib/asm/collections/smt.masm b/stdlib/asm/collections/smt.masm index 96faca1fbe..6e910668b0 100644 --- a/stdlib/asm/collections/smt.masm +++ b/stdlib/asm/collections/smt.masm @@ -605,6 +605,133 @@ proc.insert_48.2 # => [0, 0, 0, 0, R_new, ...] end +#! Replaces a leaf node at depth 16 with a subtree containing two leaf nodes at depth 32 such that +#! one of the leaf nodes commits to a key-value pair equal to the leaf node at depth 16, and the +#! other leaf node commits to the key-value pair being inserted. +#! +#! Input: [V, K, R, ...]; +#! Output:[0, 0, 0, 0, R_new, ...] +#! +#! Where: +#! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. +#! - K, V is the key-value pair for the leaf node to be inserted into the TSMT. +#! +#! This procedure consists of three high-level steps: +#! - First, insert M = hash([K_e, V_e], domain=32) into an empty subtree at depth 16, where K_e +#! and V_e are the key-value pair for the existing leaf node. This outputs the new root +#! of the subtree T. +#! - Then, insert N = hash([K, V], domain=32) into a subtree with root T. This outputs the new +#! root of the subtree P_new. +#! - Then, insert P_new into the TSMT with root R. +#! +#! This procedure succeeds only if: +#! - Node at depth 16 is a leaf node. +#! - The key in this node has a common prefix with the key to be inserted. This common prefix +#! must be greater or equal to 16, but smaller than 32. +#! +#! Cycles: 216 +proc.replace_16.3 + # save k3 into loc[0][0] - (6 cycles) + swapw dup loc_store.0 + # => [K, V, R, ...] + + # compute N = hash([K, V], domain=32) - (6 cycles) + push.0.32.0.0 swapw.2 hperm + # => [X, N, X, R, ...] + + # load the key associated with the existing leaf P from the advice provider and save it in + # loc[1] - (5 cycles) + adv_loadw loc_storew.1 + # => [K_e, N, X, R, ...]
+ + # load the value associated with the existing leaf P from the advice provider and save it in + # loc[2] - (10 cycles) + push.0.16.0.0 swapw.2 swapw.3 adv_loadw loc_storew.2 + # => [V_e, K_e, 0, 0, 16, 0, N, R, ...] + + # compute P = hash([K_e, V_e], domain=16); we will use this later to prove correct execution + # of mtree_set instruction (1 cycle) + hperm + # => [X, P, X, N, R, ...] + + # load K_e from loc[1] - (9 cycles) + push.0.32.0.0 swapw loc_loadw.1 + # => [K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # extract from the most significant element of K_e (i.e., ke_3) two most significant 16-bit + # limbs: idx_hi_e and idx_lo_e - (6 cycles) + dup exec.get_top_32_bits u32unchecked_divmod.65536 + # => [idx_lo_e, idx_hi_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # load k3 from loc[0][0] and also extract the two most significant 16-bit limbs from it + # (8 cycles) + loc_load.0 exec.get_top_32_bits u32unchecked_divmod.65536 + # => [idx_lo, idx_hi, idx_lo_e, idx_hi_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # make sure the top 16 bits of both keys are the same (4 cycles) + movup.3 dup.2 assert_eq + # => [idx_lo, idx_hi, idx_lo_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # make sure that the next 16 bits of the keys are not the same; this proves that the keys + # have the same 16-bit prefix, but not the same 32-bit prefix (6 cycles) + movup.2 dup dup.2 neq assert + # => [idx_lo_e, idx_lo, idx_hi, K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # save [idx_hi, idx_lo, idx_lo_e, 0] into loc[0] - (4 cycles) + push.0 loc_storew.0 + # => [0, idx_lo_e, idx_lo, idx_hi, K_e, 0, 0, 32, 0, P, X, N, R, ...] + + # load the value V_e from loc[2] and compute M = hash([K_e, V_e], domain=32) - (4 cycles) + loc_loadw.2 hperm + # => [X, M, X, P, X, N, R, ...] + + # load the indexes from loc[0] and drop all but idx_lo_e from the stack (7 cycles) + loc_loadw.0 drop movdn.2 drop drop + # => [idx_lo_e, M, X, P, X, N, R, ...]
+ + # push the root of an empty subtree at depth 16 onto the stack (4 cycles) + push.EMPTY_16_0.EMPTY_16_1.EMPTY_16_2.EMPTY_16_3 + # => [E16, idx_lo_e, M, X, P, X, N, R, ...] + + # insert node M into the empty subtree at depth 16; this leaves the new root of the + # subtree T together with the root of an empty subtree at depth 32 - (31 cycles) + movup.4 push.16 mtree_set + # => [E32, T, X, P, X, N, R, ...] + + # drop the E32 root as we don't need it, and arrange the stack for inserting the next + # leaf (12 cycles) + dropw swapw dropw swapw swapw.3 swapw.2 + # => [X, N, T, P, R, ...] + + # load the indexes from loc[0] and drop all but idx_lo from the stack (7 cycles) + loc_loadw.0 drop drop swap drop + # => [idx_lo, N, T, P, R, ...] + + # insert node N into the subtree with root T at depth 16; this leaves the new root of the + # subtree P_new on the stack together with the root of an empty subtree at depth 32 - (30 cycles) + push.16 mtree_set + # => [E32, P_new, P, R, ...] + + # prepare the stack for an mtree_set operation against R; we drop the E32 value as we don't + # need it; the index idx_hi is loaded from memory (10 cycles) + dropw swapw swapw.2 loc_load.0 push.16 + # => [16, idx_hi, R, P_new, P, ...] + + # insert node P_new into the TSMT at depth 16; this puts the new value of TSMT root onto the + # stack together with the old value of the node at depth 16 - (29 cycles) + mtree_set + # => [P_old, R_new, P, ...] + + # make sure P (which we computed as hash([K_e, V_e], domain=16)) and P_old are the same + # (13 cycles) + swapw swapw.2 assert_eqw + # => [R_new, ...] + + # put the return value onto the stack and return (4 cycles) + padw + # => [0, 0, 0, 0, R_new, ...] +end + #! Inserts the specified value into a Sparse Merkle Tree with the specified root under the #! specified key. #! @@ -614,7 +741,7 @@ end #! This assumes that the value is not [ZERO; 4]. If it is, the procedure fails. #! #! Input: [V, K, R, ...]; -#! Output:[V_old, R', ...] +#! 
Output:[V_old, R_new, ...] #! #! Cycles: #! - Update existing leaf: @@ -626,7 +753,7 @@ end #! - Depth 32: 181 #! - Depth 48: 181 #! - Replace a leaf with a subtree: -#! - Depth 32: TODO +#! - Depth 32: 243 #! - Depth 48: TODO export.insert # make sure the value is not [ZERO; 4] (17 cycles) @@ -701,10 +828,21 @@ export.insert end end else # --- replace leaf with subtree ---------------------------------- - # TODO: implement replace leaf with subtree - push.0 assert + if.true + if.true + # replace a leaf node at depth 16 with a subtree containing + # two leaf nodes at depth 32. + exec.replace_16 + else + # not implemented + push.0 assert + end + else + # not implemented + push.0 assert + end end end - # => [V, R, ...] + # => [V_old, R_new, ...] end diff --git a/stdlib/docs/collections/smt.md b/stdlib/docs/collections/smt.md index edb31797fe..3af197d770 100644 --- a/stdlib/docs/collections/smt.md +++ b/stdlib/docs/collections/smt.md @@ -3,4 +3,4 @@ | Procedure | Description | | ----------- | ------------- | | get | Returns the value stored under the specified key in a Sparse Merkle Tree with the specified root.

If the value for a given key has not been set, the returned `V` will consist of all zeroes.

Input: [K, R, ...]

Output: [V, R, ...]

Depth 16: 91 cycles

Depth 32: 87 cycles

Depth 48: 94 cycles

Depth 64: unimplemented | -| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R', ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 32: TODO

- Depth 48: TODO | +| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R_new, ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 32: 243

- Depth 48: TODO | diff --git a/stdlib/tests/collections/smt.rs b/stdlib/tests/collections/smt.rs index 10f6c4ad60..5c5a15fe4e 100644 --- a/stdlib/tests/collections/smt.rs +++ b/stdlib/tests/collections/smt.rs @@ -147,8 +147,9 @@ fn tsmt_insert_32() { let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); let val_b = [ONE, ONE, ZERO, ZERO]; - // TODO: test this insertion once complex inserts are working + let init_smt = smt.clone(); smt.insert(key_b.into(), val_b); + assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into()); // update a value under key_a let init_smt = smt.clone(); From 5cc332707f9abcde373d1a6a146258f2ce23882e Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Thu, 10 Aug 2023 12:15:26 +0000 Subject: [PATCH 14/18] feat: add support for module aliases --- CHANGELOG.md | 1 + assembly/src/ast/imports.rs | 7 +- assembly/src/errors.rs | 8 ++ assembly/src/tests.rs | 85 +++++++++++++++++++ assembly/src/tokens/mod.rs | 35 +++++++- .../user_docs/assembly/code_organization.md | 12 +++ 6 files changed, 140 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0311a9e05a..7777072cde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added ability to attach doc comments to re-exported procedures (#994). - Added support for nested modules (#992). - Added support for the arithmetic expressions in constant values (#1026). +- Added support for module aliases (#1037). #### VM Internals - Simplified range checker and removed 1 main and 1 auxiliary trace column (#949). 
diff --git a/assembly/src/ast/imports.rs b/assembly/src/ast/imports.rs index 4895a9a759..8c767d36b6 100644 --- a/assembly/src/ast/imports.rs +++ b/assembly/src/ast/imports.rs @@ -51,13 +51,12 @@ impl ModuleImports { while let Some(token) = tokens.read() { match token.parts()[0] { Token::USE => { - let module_path = token.parse_use()?; - let module_name = module_path.last(); - if imports.contains_key(module_name) { + let (module_path, module_name) = token.parse_use()?; + if imports.values().any(|path| *path == module_path) { return Err(ParsingError::duplicate_module_import(token, &module_path)); } - imports.insert(module_name.to_string(), module_path); + imports.insert(module_name, module_path); // consume the `use` token tokens.advance(); diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 7d307667e3..d2abab12ff 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -593,6 +593,14 @@ impl ParsingError { } } + pub fn invalid_module_name(token: &Token, name: &str) -> Self { + ParsingError { + message: format!("invalid module name: {name}"), + location: *token.location(), + op: token.to_string(), + } + } + pub fn import_inside_body(token: &Token) -> Self { ParsingError { message: "import in procedure body".to_string(), diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 197db63b7b..9bbb619d0f 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -968,6 +968,91 @@ fn program_with_reexported_proc_in_another_library() { assert!(assembler.compile(source).is_err()); } +#[test] +fn module_alias() { + const NAMESPACE: &str = "dummy"; + const MODULE: &str = "math::u64"; + const PROCEDURE: &str = r#" + export.checked_add + swap + movup.3 + u32assert.2 + u32overflowing_add + movup.3 + movup.3 + u32assert.2 + u32overflowing_add3 + eq.0 + assert + end"#; + + let namespace = LibraryNamespace::try_from(NAMESPACE.to_string()).unwrap(); + let path = LibraryPath::try_from(MODULE.to_string()).unwrap().prepend(&namespace).unwrap(); + 
let ast = ModuleAst::parse(PROCEDURE).unwrap(); + let modules = vec![Module { path, ast }]; + let library = DummyLibrary::new(namespace, modules); + + let assembler = super::Assembler::default().with_library(&library).unwrap(); + + let source = " + use.dummy::math::u64->bigint + + begin + push.1.0 + push.2.0 + exec.bigint::checked_add + end"; + let program = assembler.compile(source).unwrap(); + let expected = "\ + begin \ + span \ + pad incr pad push(2) pad \ + swap movup3 u32assert2 \ + u32add movup3 movup3 \ + u32assert2 u32add3 eqz assert \ + end \ + end"; + assert_eq!(expected, format!("{program}")); + + // --- invalid module alias ----------------------------------------------- + let source = " + use.dummy::math::u64->bigint->invalidname + + begin + push.1.0 + push.2.0 + exec.bigint->invalidname::checked_add + end"; + assert!(assembler.compile(source).is_err()); + + // --- duplicate module import -------------------------------------------- + let source = " + use.dummy::math::u64 + use.dummy::math::u64->bigint + + begin + push.1.0 + push.2.0 + exec.bigint::checked_add + end"; + + assert!(assembler.compile(source).is_err()); + + // --- duplicate module imports with different aliases -------------------- + let source = " + use.dummy::math::u64->bigint + use.dummy::math::u64->bigint2 + + begin + push.1.0 + push.2.0 + exec.bigint::checked_add + exec.bigint2::checked_add + end"; + + assert!(assembler.compile(source).is_err()); +} + #[test] fn program_with_import_errors() { // --- non-existent import ------------------------------------------------ diff --git a/assembly/src/tokens/mod.rs b/assembly/src/tokens/mod.rs index f0a96e58e8..6256e5d8d9 100644 --- a/assembly/src/tokens/mod.rs +++ b/assembly/src/tokens/mod.rs @@ -54,7 +54,7 @@ impl<'a> Token<'a> { // -------------------------------------------------------------------------------------------- pub const DOC_COMMENT_PREFIX: &str = "#!"; pub const COMMENT_PREFIX: char = '#'; - pub const EXPORT_ALIAS_DELIM: &str 
= "->"; + pub const ALIAS_DELIM: &str = "->"; // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -112,12 +112,23 @@ impl<'a> Token<'a> { // CONTROL TOKEN PARSERS / VALIDATORS // -------------------------------------------------------------------------------------------- - pub fn parse_use(&self) -> Result { + pub fn parse_use(&self) -> Result<(LibraryPath, String), ParsingError> { assert_eq!(Self::USE, self.parts[0], "not a use"); match self.num_parts() { 0 => unreachable!(), 1 => Err(ParsingError::missing_param(self)), - 2 => validate_import_path(self.parts[1], self), + 2 => { + if let Some((module_path, module_name)) = + self.parts[1].split_once(Self::ALIAS_DELIM) + { + validate_module_name(module_name, self)?; + Ok((validate_import_path(module_path, self)?, module_name.to_string())) + } else { + let module_path = validate_import_path(self.parts[1], self)?; + let module_name = module_path.last().to_string(); + Ok((module_path, module_name)) + } + } _ => Err(ParsingError::extra_param(self)), } } @@ -171,7 +182,7 @@ impl<'a> Token<'a> { // get the alias name if it exists else export it with the original name let (ref_name, proc_name) = proc_name_with_alias - .split_once(Self::EXPORT_ALIAS_DELIM) + .split_once(Self::ALIAS_DELIM) .unwrap_or((proc_name_with_alias, proc_name_with_alias)); // validate the procedure names @@ -289,3 +300,19 @@ fn validate_proc_locals(locals: &str, token: &Token) -> Result Err(ParsingError::invalid_proc_locals(token, locals)), } } + +/// A module name must comply with the following rules: +/// - The name must be between 1 and 255 characters long. +/// - The name must start with an ASCII letter. +/// - The name can contain only ASCII letters, numbers, or underscores. 
+fn validate_module_name(name: &str, token: &Token) -> Result<(), ParsingError> { + if name.is_empty() + || name.len() > crate::MAX_LABEL_LEN + || !name.chars().next().unwrap().is_ascii_alphabetic() + || !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') + { + Err(ParsingError::invalid_module_name(token, name)) + } else { + Ok(()) + } +} diff --git a/docs/src/user_docs/assembly/code_organization.md b/docs/src/user_docs/assembly/code_organization.md index 65ee845bc5..06bc5e72f2 100644 --- a/docs/src/user_docs/assembly/code_organization.md +++ b/docs/src/user_docs/assembly/code_organization.md @@ -90,6 +90,18 @@ end ``` In the above example we import `std::math::u64` module from the [standard library](../stdlib/main.md). We then execute a program which pushes two 64-bit integers onto the stack, and then invokes a 64-bit addition procedure from the imported module. +We can also define aliases for imported modules. For example: + +``` +use.std::math::u64->bigint + +begin + push.1.0 + push.2.0 + exec.bigint::checked_add +end +``` + The set of modules which can be imported by a program can be specified via a Module Provider when instantiating the [Miden Assembler](https://crates.io/crates/miden-assembly) used to compile the program. 
#### Re-exporting procedures From 57110fa60c27e64cfdf3c8a6fd8cb4a39a79baf2 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 12 Aug 2023 08:47:16 -0700 Subject: [PATCH 15/18] fix: error message format --- assembly/src/errors.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index d2abab12ff..51f699dc84 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -143,7 +143,7 @@ impl fmt::Display for AssemblyError { LibraryError(err) | ParsingError(err) | ProcedureNameError(err) => write!(f, "{err}"), LocalProcNotFound(proc_idx, module_path) => write!(f, "procedure at index {proc_idx} not found in module {module_path}"), ParamOutOfBounds(value, min, max) => write!(f, "parameter value must be greater than or equal to {min} and less than or equal to {max}, but was {value}"), - PhantomCallsNotAllowed(mast_root) => write!(f, "cannot call phantom procedure with MAST root 0x{mast_root}: phantom calls not allowed"), + PhantomCallsNotAllowed(mast_root) => write!(f, "cannot call phantom procedure with MAST root {mast_root}: phantom calls not allowed"), SysCallInKernel(proc_name) => write!(f, "syscall instruction used in kernel procedure '{proc_name}'"), } } From 9bcbd2418ee84c3dfe59c2e7a91c82d27416e118 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Fri, 11 Aug 2023 00:11:33 -0700 Subject: [PATCH 16/18] feat: implement TSMT procedure to handle complex inserts at depth 48 --- core/src/operations/decorators/advice.rs | 36 ++- processor/src/advice/providers.rs | 10 +- processor/src/advice/source.rs | 3 + .../src/decorators/adv_stack_injectors.rs | 250 +++++++++++----- processor/src/errors.rs | 14 +- stdlib/asm/collections/smt.masm | 280 +++++++++++++++++- stdlib/docs/collections/smt.md | 2 +- stdlib/tests/collections/smt.rs | 32 +- 8 files changed, 520 insertions(+), 107 deletions(-) diff --git a/core/src/operations/decorators/advice.rs b/core/src/operations/decorators/advice.rs index 
d87ec83219..b6c7229d84 100644 --- a/core/src/operations/decorators/advice.rs +++ b/core/src/operations/decorators/advice.rs @@ -151,6 +151,39 @@ pub enum AdviceInjector { /// - f2 is a boolean flag set to `1` if a remaining key is not zero. SmtGet, + /// Pushes values onto the advice stack which are required for successful insertion of a + /// key-value pair into a Sparse Merkle Tree data structure. + /// + /// The Sparse Merkle Tree is tiered, meaning it will have leaf depths in `{16, 32, 48, 64}`. + /// + /// Inputs: + /// Operand stack: [VALUE, KEY, ROOT, ...] + /// Advice stack: [...] + /// + /// Outputs: + /// Operand stack: [OLD_VALUE, NEW_ROOT, ...] + /// Advice stack depends on the type of insert operation as follows: + /// - Update of an existing leaf: [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE] + /// - Simple insert at depth 16: [d0, d1, ONE (is_simple_insert), ZERO (is_update)] + /// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE] + /// - Complex insert: [f0, f1, ZERO (is_simple_insert), ZERO (is_update), E_KEY, E_VALUE] + /// + /// Where: + /// - ROOT and NEW_ROOT are the roots of the TSMT before and after the insert respectively. + /// - VALUE is the value to be inserted. + /// - OLD_VALUE is the value previously associated with the specified KEY. + /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. + /// - d1 is a boolean flag set to `1` if the depth is `16` or `32`. + /// - P_NODE is an internal node located at the tier above the insert tier. + /// - f0 and f1 are boolean flags a combination of which determines the source and the target + /// tiers as follows: + /// - (0, 0): depth 16 -> 32 + /// - (0, 1): depth 16 -> 48 + /// - (1, 0): depth 32 -> 48 + /// - (1, 1): depth 16, 32, or 48 -> 64 + /// - E_KEY and E_VALUE are the key-value pair for a leaf which is to be replaced by a subtree. 
+ SmtInsert, + // ADVICE MAP INJECTORS // -------------------------------------------------------------------------------------------- /// Reads words from memory at the specified range and inserts them into the advice map under @@ -181,9 +214,6 @@ pub enum AdviceInjector { /// Where KEY is computed as hash(A || B, domain), where domain is provided via the immediate /// value. HdwordToMap { domain: Felt }, - - /// TODO: add docs - SmtInsert, } impl fmt::Display for AdviceInjector { diff --git a/processor/src/advice/providers.rs b/processor/src/advice/providers.rs index 11b350b654..18c151fa18 100644 --- a/processor/src/advice/providers.rs +++ b/processor/src/advice/providers.rs @@ -83,22 +83,24 @@ where match source { AdviceSource::Value(value) => { self.stack.push(value); - Ok(()) } - + AdviceSource::Word(word) => { + self.stack.extend(word.iter().rev()); + } AdviceSource::Map { key, include_len } => { let values = self .map .get(&key.into_bytes()) - .ok_or(ExecutionError::AdviceKeyNotFound(key))?; + .ok_or(ExecutionError::AdviceMapKeyNotFound(key))?; self.stack.extend(values.iter().rev()); if include_len { self.stack.push(Felt::from(values.len() as u64)); } - Ok(()) } } + + Ok(()) } fn insert_into_map(&mut self, key: Word, values: Vec) -> Result<(), ExecutionError> { diff --git a/processor/src/advice/source.rs b/processor/src/advice/source.rs index b2ba99e42f..d87f7ed436 100644 --- a/processor/src/advice/source.rs +++ b/processor/src/advice/source.rs @@ -9,6 +9,9 @@ pub enum AdviceSource { /// Puts a single value onto the advice stack. Value(Felt), + /// Puts a word (4 elements) onto the advice stack. + Word(Word), + /// Fetches a list of elements under the specified key from the advice map and pushes them onto /// the advice stack.
/// diff --git a/processor/src/decorators/adv_stack_injectors.rs b/processor/src/decorators/adv_stack_injectors.rs index ee240ef740..b22d86b1cb 100644 --- a/processor/src/decorators/adv_stack_injectors.rs +++ b/processor/src/decorators/adv_stack_injectors.rs @@ -1,7 +1,8 @@ use super::{super::Ext2InttError, AdviceProvider, AdviceSource, ExecutionError, Process}; use vm_core::{ - crypto::merkle::EmptySubtreeRoots, utils::collections::Vec, Felt, FieldElement, QuadExtension, - StarkField, Word, ONE, ZERO, + crypto::{hash::RpoDigest, merkle::EmptySubtreeRoots}, + utils::collections::Vec, + Felt, FieldElement, QuadExtension, StarkField, Word, ONE, WORD_SIZE, ZERO, }; use winter_prover::math::fft; @@ -350,31 +351,26 @@ where /// /// Outputs: /// Operand stack: [OLD_VALUE, NEW_ROOT, ...] - /// Advice stack, depends on the type of insert: - /// - Simple insert at depth 16: [d0, d1, ONE (is_simple_insert), ZERO (is_update)] - /// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE] - /// - Update of an existing leaf: [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE] - /// - Replace leaf node with subtree 16->32: [ONE, ONE, ZERO, ZERO, P_KEY, P_VALUE] - /// - Update of an existing leaf: [ONE, d0, d1, ONE, OLD_VALUE] + /// Advice stack: see comments for specialized handlers below. /// /// Where: - /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. - /// - d1 is a boolean flag set to `1` if the depth is `16` or `32`. - /// - P_NODE is an internal node located at the tier above the insert tier. + /// - ROOT and NEW_ROOT are the roots of the TSMT before and after the insert respectively. /// - VALUE is the value to be inserted. /// - OLD_VALUE is the value previously associated with the specified KEY. - /// - P_KEY and P_VALUE are the key-value pair for a leaf which is to be replaced by a subtree. - /// - ROOT and NEW_ROOT are the roots of the TSMT prior and post the insert respectively. 
/// /// # Errors - /// Will return an error if the provided Merkle root doesn't exist on the advice provider. + /// Will return an error if: + /// - The Merkle store does not contain a node with the specified root. + /// - The Merkle store does not contain all nodes needed to validate the path between the root + /// and the relevant TSMT nodes. + /// - The advice map does not contain required data about TSMT leaves to be modified. /// /// # Panics /// Will panic as unimplemented if the target depth is `64`. pub(super) fn push_smtinsert_inputs(&mut self) -> Result<(), ExecutionError> { // get the key and tree root from the stack - let key = [self.stack.get(7), self.stack.get(6), self.stack.get(5), self.stack.get(4)]; - let root = [self.stack.get(11), self.stack.get(10), self.stack.get(9), self.stack.get(8)]; + let key = self.stack.get_word(1); + let root = self.stack.get_word(2); // determine the depth of the first leaf or an empty tree node let index = &key[3]; @@ -397,81 +393,169 @@ where // - if the node is a leaf, this could be either an update (for the same key), or a // complex insert (i.e., the existing leaf needs to be moved to a lower tier). 
let empty = EmptySubtreeRoots::empty_hashes(64)[depth as usize]; - let (is_update, is_simple_insert) = if node == Word::from(empty) { - // handle simple insert case - if depth == 32 || depth == 48 { + if node == Word::from(empty) { + self.handle_smt_simple_insert(root, depth, index)?; + } else { + // get the key and value stored in the current leaf + let (leaf_key, leaf_value) = self.get_smt_upper_leaf_preimage(node)?; + + // if the key for the value to be inserted is the same as the leaf's key, we are + // dealing with a simple update; otherwise, we are dealing with a complex insert + if leaf_key == key { + self.handle_smt_update(depth, leaf_value)?; + } else { + self.handle_smt_complex_insert(depth, key, leaf_key, leaf_value)?; + } + } + + Ok(()) + } + + // TSMT INSERT HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Retrieves a key-value pair for the specified leaf node from the advice map. + /// + /// # Errors + /// Returns an error if the value under the specified node does not exist or does not consist + /// of exactly 8 elements. + fn get_smt_upper_leaf_preimage(&self, node: Word) -> Result<(Word, Word), ExecutionError> { + let node_bytes = RpoDigest::from(node).as_bytes(); + let kv = self + .advice_provider + .get_mapped_values(&node_bytes) + .ok_or(ExecutionError::AdviceMapKeyNotFound(node))?; + + if kv.len() != WORD_SIZE * 2 { + return Err(ExecutionError::AdviceMapValueInvalidLength(node, WORD_SIZE * 2, kv.len())); + } + + let key = [kv[0], kv[1], kv[2], kv[3]]; + let val = [kv[4], kv[5], kv[6], kv[7]]; + Ok((key, val)) + } + + /// Prepares the advice stack for a TSMT update operation. Specifically, the advice stack will + /// be arranged as follows: + /// + /// - [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE] + /// + /// Where: + /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. + /// - d1 is a boolean flag set to `1` if the depth is `16` or `32`. 
+ /// - OLD_VALUE is the current value in the leaf to be updated. + fn handle_smt_update(&mut self, depth: u8, old_value: Word) -> Result<(), ExecutionError> { + // put the old value onto the advice stack + self.advice_provider.push_stack(AdviceSource::Word(old_value))?; + + // set is_update flag to ONE + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; + + // set depth flags based on leaf's depth + let (is_16_or_32, is_16_or_48) = get_depth_flags(depth); + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; + + // pad the advice stack with an extra value to make it consistent with other cases when + // we expect 4 flag values on the top of the advice stack + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + + Ok(()) + } + + /// Prepares the advice stack for a TSMT simple insert operation (i.e., when we are replacing + /// an empty node). Specifically, the advice stack will be arranged as follows: + /// + /// - Simple insert at depth 16: [d0, d1, ONE (is_simple_insert), ZERO (is_update)] + /// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE] + /// + /// Where: + /// - d0 is a boolean flag set to `1` if the depth is `16` or `48`. + /// - d1 is a boolean flag set to `1` if the depth is `16` or `32`. + /// - P_NODE is an internal node located at the tier above the insert tier. 
+ fn handle_smt_simple_insert( + &mut self, + root: Word, + depth: u8, + index: Felt, + ) -> Result<(), ExecutionError> { + // put additional data onto the advice stack as needed + match depth { + 16 => (), // nothing to do; all the required data is already in the VM + 32 | 48 => { // for depth 32 and 48, we need to provide the internal node located on the tier // above the insert tier let p_index = Felt::from(index.as_int() >> 16); let p_depth = Felt::from(depth - 16); let p_node = self.advice_provider.get_tree_node(root, &p_depth, &p_index)?; - for &element in p_node.iter().rev() { - self.advice_provider.push_stack(AdviceSource::Value(element))?; - } + self.advice_provider.push_stack(AdviceSource::Word(p_node))?; } + 64 => unimplemented!("insertions at depth 64 are not yet implemented"), + _ => unreachable!("invalid depth {depth}"), + } - // return is_update = ZERO, is_simple_insert = ONE - (ZERO, ONE) - } else { - // if the node is a leaf node, push the elements mapped to this node onto the advice - // stack; the elements should be [KEY, VALUE], with key located at the top of the - // advice stack. - self.advice_provider.push_stack(AdviceSource::Map { - key: node, - include_len: false, - })?; - - // remove the KEY from the advice stack, leaving only the VALUE on the stack - let leaf_key = self.advice_provider.pop_stack_word()?; + // push is_update and is_simple_insert flags onto the advice stack + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; - // if the key for the value to be inserted is the same as the leaf's key, we are - // dealing with a simple update. otherwise, we are dealing with a complex insert - // (i.e., the leaf needs to be moved to a lower tier). 
- if leaf_key == key { - // return is_update = ONE, is_simple_insert = ZERO - (ONE, ZERO) - } else { - // TODO: improve code readability as more cases are handled - let common_prefix = get_common_prefix(&key, &leaf_key); - if depth == 16 { - if common_prefix < 32 { - // put the key back onto the advice stack - for &element in leaf_key.iter().rev() { - self.advice_provider.push_stack(AdviceSource::Value(element))?; - } - } else { - todo!("handle moving leaf from depth 16 to 48 or 64") - } - } else if depth == 32 { - todo!("handle moving leaf from depth 32 to 48 or 64") - } else if depth == 48 { - todo!("handle moving leaf from depth 48 to 64") - } else { - todo!("handle inserting key-value pair into existing leaf at depth 64") - } - - // return is_update = ZERO, is_simple_insert = ZERO - (ZERO, ZERO) - } - }; + // set depth flags based on node's depth + let (is_16_or_32, is_16_or_48) = get_depth_flags(depth); + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; + self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; - // set the flags used to determine which tier the insert is happening at - let is_16_or_32 = if depth == 16 || depth == 32 { ONE } else { ZERO }; - let is_16_or_48 = if depth == 16 || depth == 48 { ONE } else { ZERO }; + Ok(()) + } - self.advice_provider.push_stack(AdviceSource::Value(is_update))?; - if is_update == ONE { - // for update we don't need to specify whether we are dealing with an insert; but we - // insert an extra ONE at the end so that we can read 4 values from the advice stack - // regardless of which branch is taken. 
- self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; - self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; - self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; - } else { - self.advice_provider.push_stack(AdviceSource::Value(is_simple_insert))?; - self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?; - self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?; + /// Prepares the advice stack for a TSMT complex insert operation (i.e., when a leaf node needs + /// to be replaced with a subtree of nodes at a lower tier). Specifically, the advice stack + /// will be arranged as follows: + /// + /// - [d0, d1, ZERO (is_simple_insert), ZERO (is_update), E_KEY, E_VALUE] + /// + /// Where: + /// - d0 and d1 are boolean flags a combination of which determines the source and the target + /// tiers as follows: + /// - (0, 0): depth 16, 32, or 48 -> 64 + /// - (0, 1): depth 16 -> 48 + /// - (1, 0): depth 32 -> 48 + /// - (1, 1): depth 16 -> 32 + /// - E_KEY and E_VALUE are the key-value pair for a leaf which is to be replaced by a subtree. 
+ fn handle_smt_complex_insert( + &mut self, + depth: u8, + key: Word, + leaf_key: Word, + leaf_value: Word, + ) -> Result<(), ExecutionError> { + // push the key and value onto the advice stack + self.advice_provider.push_stack(AdviceSource::Word(leaf_value))?; + self.advice_provider.push_stack(AdviceSource::Word(leaf_key))?; + + // push is_update and is_simple_insert flags onto the advice stack + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + + // determine the combination of the source and target tiers for the insert + // and populate the depth flags accordingly + let common_prefix = get_common_prefix(&key, &leaf_key); + let target_depth = SMT_NORMALIZED_DEPTHS[common_prefix as usize + 1]; + match target_depth { + 32 if depth == 16 => { + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; + } + 48 if depth == 16 => { + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + } + 48 if depth == 32 => { + self.advice_provider.push_stack(AdviceSource::Value(ZERO))?; + self.advice_provider.push_stack(AdviceSource::Value(ONE))?; + } + 64 => unimplemented!("insertions at depth 64 are not yet implemented"), + _ => unreachable!("invalid source/target tier combination: {depth} -> {target_depth}"), } + Ok(()) } } @@ -490,3 +574,9 @@ fn get_common_prefix(key1: &Word, key2: &Word) -> u8 { let k2 = key2[3].as_int(); (k1 ^ k2).leading_zeros() as u8 } + +fn get_depth_flags(depth: u8) -> (Felt, Felt) { + let is_16_or_32 = if depth == 16 || depth == 32 { ONE } else { ZERO }; + let is_16_or_48 = if depth == 16 || depth == 48 { ONE } else { ZERO }; + (is_16_or_32, is_16_or_48) +} diff --git a/processor/src/errors.rs b/processor/src/errors.rs index e95c0c5649..4c616dbf1d 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -15,7 +15,8 @@ use 
std::error::Error; #[derive(Debug)] pub enum ExecutionError { - AdviceKeyNotFound(Word), + AdviceMapKeyNotFound(Word), + AdviceMapValueInvalidLength(Word, usize, usize), AdviceStackReadFailed(u32), CallerNotInSyscall, CodeBlockNotFound(Digest), @@ -47,9 +48,16 @@ impl Display for ExecutionError { use ExecutionError::*; match self { - AdviceKeyNotFound(key) => { + AdviceMapKeyNotFound(key) => { let hex = to_hex(Felt::elements_as_bytes(key))?; - write!(f, "Can't push values onto the advice stack: value for key {hex} not present in the advice map.") + write!(f, "Value for key {hex} not present in the advice map") + } + AdviceMapValueInvalidLength(key, expected, actual) => { + let hex = to_hex(Felt::elements_as_bytes(key))?; + write!( + f, + "Expected value for key {hex} to contain {expected} elements, but was {actual}" + ) } AdviceStackReadFailed(step) => write!(f, "Advice stack read failed at step {step}"), CallerNotInSyscall => { diff --git a/stdlib/asm/collections/smt.masm b/stdlib/asm/collections/smt.masm index 6e910668b0..c90e904a03 100644 --- a/stdlib/asm/collections/smt.masm +++ b/stdlib/asm/collections/smt.masm @@ -56,6 +56,67 @@ proc.get_top_48_bits add end +#! Extracts top 16 and the next 32 bits from the most significant elements of U and V. +#! +#! Also verifies that the top 32 bits of these elements are the same, while the next 16 bits are +#! different. +#! +#! Input: [U, V, ...] +#! Output: [(u3 << 16) >> 32, (v3 << 16) >> 32, v3 >> 48, U, V, ...] +#! +#! Cycles: 30 +proc.get_common_prefix_16 + # split the most significant elements of U and V into 32-bit chunks and make sure the top + # 32 bit chunks are the same (i.e. u3_hi = v3_hi) - (8 cycles) + dup.4 u32split dup.2 u32split dup movup.3 assert_eq + # => [u3_hi, u3_lo, v3_lo, U, V, ...] + + u32unchecked_divmod.65536 mul.65536 + # => [idx_mid, idx_hi, u3_lo, v3_lo, U, V, ...] + + movup.3 u32unchecked_shr.16 + # => [v3_lo_hi, idx_mid, idx_hi, u3_lo, U, V, ...] 
+ + dup dup.2 add + # => [idx_lo_v, v3_lo_hi, idx_mid, idx_hi, u3_lo, U, V, ...] + + movup.4 u32unchecked_shr.16 + # => [u3_lo_hi, idx_lo_v, v3_lo_hi, idx_mid, idx_hi, U, V, ...] + + dup movup.3 neq assert + # => [u3_lo_hi, idx_lo_v, idx_mid, idx_hi, U, V, ...] + + movup.2 add + # => [idx_lo_u, idx_lo_v, idx_hi, U, V, ...] +end + +#! Extracts top 32 bits and the next 16 bits from the most significant elements of U and V. +#! +#! Also verifies that the top 32 bits of these elements are the same, while the next 16 bits are +#! different. +#! +#! Input: [U, V, ...] +#! Output: [(u3 << 32) >> 48, (v3 << 32) >> 48, v3 >> 32, U, V, ...] +#! +#! Cycles: 20 +proc.get_common_prefix_32 + # split the most significant elements of U and V into 32-bit chunks (4 cycles) + dup.4 u32split dup.2 u32split + # => [u3_hi, u3_lo, v3_hi, v3_lo, U, V, ...] + + # make sure that the top 32 bit chunks are the same (3 cycles) + dup.2 assert_eq + # => [u3_lo, idx_hi, v3_lo, U, V, ...] + + # drop the least significant 16 bits from the lower 32-bit chunks (8 cycles) + u32unchecked_shr.16 movup.2 u32unchecked_shr.16 swap + # => [idx_lo_u, idx_lo_v, idx_hi, U, V, ...] + + # make sure the lower 16-bit chunks are different (5 cycles) + dup dup.2 neq assert + # => [idx_lo_u, idx_lo_v, idx_hi, U, V, ...] +end + # GET # ================================================================================================= @@ -630,7 +691,7 @@ end #! must be greater or equal to 16, but smaller than 32. #! #! Cycles: 216 -proc.replace_16.3 +proc.replace_32.3 # save k3 into loc[0][0] - (6 cycles) swapw dup loc_store.0 # => [K, V, R, ...] @@ -732,6 +793,157 @@ proc.replace_16.3 # => [0, 0, 0, 0, R_new, ...] end +#! Replaces a leaf node at depth 16 or 32 with a subtree containing two leaf nodes at depth 48 +#! such that one of the leaf nodes commits to a key-value pair equal to the leaf node at the +#! original depth, and the other leaf node commits to the key-value pair being inserted. +#! +#! 
Input: [E, idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; +#! Output:[0, 0, 0, 0, R_new, ...] +#! +#! Where: +#! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. +#! - K, V is the key-value pair for the leaf node to be inserted into the TSMT. +#! - d is the depth of the current leaf node (i.e., depth 16 or 32). +#! - idx_hi is the index of the last common node on the path from R to the leaves at depth 48. +#! - idx_lo_e and idx_lo_n are the indexes of the new leaf nodes in a subtree rooted in the +#! last common node. +#! - E is a root of an empty subtree at depth d. +#! +#! This procedure consists of three high-level steps: +#! - First, insert M = hash([K_e, V_e], domain=48) into an empty subtree at depth 48 - d, where +#! K_e and V_e are the key-value pair for the existing leaf node at depth d. This outputs the +#! new root of the subtree T. +#! - Then, insert N = hash([K, V], domain=48) into a subtree with root T. This outputs the new +#! root of the subtree P_new. +#! - Then, insert P_new into the TSMT with root R at depth d. +#! +#! This procedure succeeds only if: +#! - Node at depth d is a leaf node. +#! +#! The procedure assumes but does not check that: +#! - d is either 16 or 32. +#! - idx_hi is within range valid for depth d. +#! - idx_lo_e and idx_lo_n are different values. +#! - idx_lo_e and idx_lo_n are within range valid for depth 48 - d. +#! +#! 
Cycles: 195 +proc.replace_48.4 + # save E into loc[3] and drop it from the stack (7 cycles) + loc_storew.3 dropw + # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; + + # save [d, idx_hi, idx_lo_n, idx_lo_e] into loc[0] (3 cycles) + loc_storew.0 + # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; + + # prepare the stack for computing P = hash([K_e, V_e], domain=d) + + # load V_e from the advice provider and save it into loc[1] (5 cycles) + adv_loadw loc_storew.1 + # => [V_e, K_e, K, V, R, ...]; + + # (6 cycles) + push.0 loc_load.0 push.0.0 + # => [0, 0, d, 0, V_e, K_e, K, V, R, ...]; + + # save K_e into loc[2] - (5 cycles) + swapw.2 loc_storew.2 swapw + # => [V_e, K_e, 0, 0, d, 0, K, V, R, ...]; + + # compute P = hash([K_e, V_e], domain=d) (1 cycle) + hperm + # => [X, P, X, K, V, R, ...]; + + # prepare the stack for computing M = hash([K_e, V_e], domain=48) + + # load K_e and V_e from loc[2] and loc[1] respectively (13 cycles) + loc_loadw.2 push.0.48.0.0 swapw.2 swapw.3 loc_loadw.1 + # => [V_e, K_e, 0, 0, 48, 0, P, K, V, R, ...]; + + # compute M = hash([K_e, V_e], domain=48) (1 cycle) + hperm + # => [X, M, X, P, K, V, R, ...]; + + # load the root of empty subtree at depth d from loc[3] (3 cycles) + loc_loadw.3 + # => [E, M, X, P, K, V, R, ...]; + + # prepare the stack for inserting M into E + + # (5 cycles) + swapw swapw.2 loc_loadw.0 + # => [idx_lo_e, idx_lo_n, idx_hi, d, E, M, P, K, V, R, ...]; + + # (6 cycles) + movdn.3 drop drop neg add.48 + # => [48 - d, idx_lo_e, E, M, P, K, V, R, ...]; + + # insert M into an empty subtree rooted at E; this leaves a root of empty subtree at depth 48 + # on the stack - though, we don't need to verify this (29 cycles) + mtree_set + # => [E48, T, P, K, V, R, ...]; + + # prepare the stack for computing N = hash([K, V], domain=48) + + # (5 cycles) + dropw swapdw + # => [K, V, T, P, R, ...]; + + # (5 cycles) + push.0.48.0.0 swapw.2 + # => [V, K, 0, 0, 48, 0, T, P, R, ...]; + + # compute N = hash([K, V], domain=48) - 
(1 cycles) + hperm + # => [X, N, X, T, P, R, ...]; + + # prepare the stack for inserting N into T + + # (6 cycles) + dropw swapw.2 swapw + # => [X, T, N, P, R, ...]; + + # (3 cycles) + loc_loadw.0 + # => [idx_lo_e, idx_lo_n, idx_hi, d, T, N, P, R, ...]; + + # (6 cycles) + drop movdn.2 drop neg add.48 + # => [48 - d, idx_lo_n, T, N, P, R, ...]; + + # insert N into a subtree with root T; this leaves a root of an empty subtree at depth 48 + # on the stack - though, we don't need to verify this (29 cycles) + mtree_set + # => [E48, P_new, P, R, ...]; + + # prepare the stack for inserting P_new into R + + # (4 cycles) + swapw.3 swapw swapw.2 swapw.3 + # => [E48, R, P_new, P, ...]; + + # (3 cycles) + loc_loadw.0 + # => [idx_lo_e, idx_lo_n, idx_hi, d, R, P_new, P, ...]; + + # (3 cycles) + drop drop swap + # => [d, idx_hi, R, P_new, P, ...]; + + # insert P_new into the tree rooted at R; this also leaves P_old (the old value of the node) + # on the stack (29 cycles) + mtree_set + # => [P_old, R_new, P, ...]; + + # make sure P and P_old are the same (13 cycles) + swapw swapw.2 assert_eqw + # => [R_new, ...]; + + # put the return value onto the stack and return (4 cycles) + padw + # => [0, 0, 0, 0, R_new, ...] +end + #! Inserts the specified value into a Sparse Merkle Tree with the specified root under the #! specified key. #! @@ -753,8 +965,9 @@ end #! - Depth 32: 181 #! - Depth 48: 181 #! - Replace a leaf with a subtree: -#! - Depth 32: 243 -#! - Depth 48: TODO +#! - Depth 16 -> 32: 243 +#! - Depth 16 -> 48: 263 +#! - Depth 32 -> 48: 253 export.insert # make sure the value is not [ZERO; 4] (17 cycles) repeat.4 @@ -770,7 +983,7 @@ export.insert # => [is_update, f0, f1, f2, V, K, R, ...] 
# call the inner procedure depending on the type of insert and depth - if.true # --- update leaf ------------------------------------------------- + if.true # --- update leaf --------------------------------------------------------------------- # => [is_16_or_32, is_16_or_48, ZERO, V, K, R, ...] if.true if.true # --- update a leaf node at depth 16 --- @@ -812,7 +1025,7 @@ export.insert end else # => [is_simple_insert, is_16_or_32, is_16_or_48, V, K, R, ...] - if.true # --- inset new leaf ---------------------------------------------- + if.true # --- insert new leaf ------------------------------------------------------------- if.true if.true exec.insert_16 @@ -827,19 +1040,56 @@ export.insert push.0 assert end end - else # --- replace leaf with subtree ---------------------------------- + else # --- replace leaf with subtree ------------------------------------------------------ if.true - if.true - # replace a leaf node at depth 16 with a subtree containing - # two leaf nodes at depth 32. - exec.replace_16 - else - # not implemented - push.0 assert + if.true # --- replace a leaf at depth 16 with two leaves at depth 32 --- + exec.replace_32 + else # --- replace a leaf at depth 16 with two leaves at depth 48 --- + # load K_e from the advice provider (5 cycles) + swapw adv_push.4 + # => [K_e, K, V, R, ...] + + # (30 cycles) + exec.get_common_prefix_16 + # => [idx_lo_e, idx_lo, idx_hi, K_e, K, V, R, ...] + + # (2 cycles) + push.16 movdn.3 + # => [idx_lo_e, idx_lo, idx_hi, 16, K_e, K, V, R, ...] + + # (4 cycles) + push.EMPTY_16_0.EMPTY_16_1.EMPTY_16_2.EMPTY_16_3 + # => [E, idx_lo_e, idx_lo, idx_hi, 16, K_e, K, V, R, ...] + + # (195 cycles) + exec.replace_48 + # => [0, 0, 0, 0, R_new, ...] end else - # not implemented - push.0 assert + if.true # --- replace a leaf at depth 32 with two leaves at depth 48 --- + # load K_e from the advice provider (5 cycles) + swapw adv_push.4 + # => [K_e, K, V, R, ...] 
+ + # (20 cycles) + exec.get_common_prefix_32 + # => [idx_lo_e, idx_lo, idx_hi, K_e, K, V, R, ...] + + # (2 cycles) + push.32 movdn.3 + # => [idx_lo_e, idx_lo, idx_hi, 32, K_e, K, V, R, ...] + + # (4 cycles) + push.EMPTY_32_0.EMPTY_32_1.EMPTY_32_2.EMPTY_32_3 + # => [E, idx_lo_e, idx_lo, idx_hi, 32, K_e, K, V, R, ...] + + # (195 cycles) + exec.replace_48 + # => [0, 0, 0, 0, R_new, ...] + else # --- replace a leaf at depth 16, 32, or 48 with two leaves at depth 64 --- + # depth 64 - currently not implemented + push.0 assert + end end end end diff --git a/stdlib/docs/collections/smt.md b/stdlib/docs/collections/smt.md index 3af197d770..9d31187bc6 100644 --- a/stdlib/docs/collections/smt.md +++ b/stdlib/docs/collections/smt.md @@ -3,4 +3,4 @@ | Procedure | Description | | ----------- | ------------- | | get | Returns the value stored under the specified key in a Sparse Merkle Tree with the specified root.

If the value for a given key has not been set, the returned `V` will consist of all zeroes.

Input: [K, R, ...]

Output: [V, R, ...]

Depth 16: 91 cycles

Depth 32: 87 cycles

Depth 48: 94 cycles

Depth 64: unimplemented | -| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R_new, ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 32: 243

- Depth 48: TODO | +| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R_new, ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 16 -> 32: 243

- Depth 16 -> 48: 263

- Depth 32 -> 48: 253 | diff --git a/stdlib/tests/collections/smt.rs b/stdlib/tests/collections/smt.rs index 5c5a15fe4e..e47ceb860c 100644 --- a/stdlib/tests/collections/smt.rs +++ b/stdlib/tests/collections/smt.rs @@ -147,6 +147,7 @@ fn tsmt_insert_32() { let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); let val_b = [ONE, ONE, ZERO, ZERO]; + // this tests a complex insertion when a leaf node moves from depth 16 to depth 32 let init_smt = smt.clone(); smt.insert(key_b.into(), val_b); assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into()); @@ -182,8 +183,10 @@ fn tsmt_insert_48() { let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); let val_b = [ONE, ONE, ZERO, ZERO]; - // TODO: test this insertion once complex inserts are working + // this tests a complex insertion when a leaf moves from depth 16 to depth 48 + let init_smt = smt.clone(); smt.insert(key_b.into(), val_b); + assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into()); // update a value under key_a let init_smt = smt.clone(); @@ -201,6 +204,33 @@ fn tsmt_insert_48() { assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); } +#[test] +fn tsmt_insert_48_from_32() { + let mut smt = TieredSmt::default(); + + // insert a value under key_a into an empty tree + let raw_a = 0b00000000_00000000_11111111_11111111_11111111_11111111_11111111_11111111_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_a.into(), val_a); + + // insert a value under key_b which has the same 16-bit prefix as A + let raw_b = 0b00000000_00000000_01111111_11111111_01111111_11111111_11111111_11111111_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_b.into(), val_b); + + // insert a value under key_c which has the same 32-bit prefix as A + let raw_c = 
0b00000000_00000000_11111111_11111111_00111111_11111111_11111111_11111111_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let val_c = [ONE, ONE, ONE, ZERO]; + + // this tests a complex insertion when a leaf moves from depth 32 to depth 48 + let init_smt = smt.clone(); + smt.insert(key_c.into(), val_c); + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); +} + fn assert_insert( init_smt: &TieredSmt, key: RpoDigest, From c3d7b4b07c2b3c89e1773e8cacfd78225b70c265 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 12 Aug 2023 21:26:49 -0700 Subject: [PATCH 17/18] refactor: make TSMT replace32 work similarly to replace48 --- stdlib/asm/collections/smt.masm | 240 ++++++++++++++++---------------- stdlib/docs/collections/smt.md | 2 +- 2 files changed, 123 insertions(+), 119 deletions(-) diff --git a/stdlib/asm/collections/smt.masm b/stdlib/asm/collections/smt.masm index c90e904a03..5a00de3b6f 100644 --- a/stdlib/asm/collections/smt.masm +++ b/stdlib/asm/collections/smt.masm @@ -56,6 +56,33 @@ proc.get_top_48_bits add end +#! Extracts top 16 and the next 16 bits from the most significant elements of U and V. +#! +#! Also verifies that the top 16 bits of these elements are the same, while the next 16 bits are +#! different. +#! +#! Input: [U, V, ...] +#! Output: [(u3 << 16) >> 48, (v3 << 16) >> 48, v3 >> 48, U, V, ...] +#! +#! Cycles: 20 +proc.extract_index_16_16 + # extract the top 16 and the next 16 bits from the most significant element of V (6 cycles) + dup.4 u32split swap drop u32unchecked_divmod.65536 + # => [v3_hi_lo, v3_hi_hi, U, V, ...] + + # extract the top 16 and the next 16 bits from the most significant element of U (4 cycles) + dup.2 u32split u32unchecked_divmod.65536 + # => [u3_hi_lo, u3_hi_hi, u3_lo, v3_hi_lo, v3_hi_hi, U, V, ...] + + # make sure the lower 16 bits are different (5 cycles) + dup dup.4 neq assert + # => [u3_hi_lo, u3_hi_hi, u3_lo, v3_hi_lo, v3_hi_hi, U, V, ...] 
+ + # make sure the top 16 bits are the same (5 cycles) + movdn.2 dup.4 assert_eq drop + # => [u3_hi_lo, v3_hi_lo, v3_hi_hi, U, V, ...] +end + #! Extracts top 16 and the next 32 bits from the most significant elements of U and V. #! #! Also verifies that the top 32 bits of these elements are the same, while the next 16 bits are @@ -65,7 +92,7 @@ end #! Output: [(u3 << 16) >> 32, (v3 << 16) >> 32, v3 >> 48, U, V, ...] #! #! Cycles: 30 -proc.get_common_prefix_16 +proc.extract_index_16_32 # split the most significant elements of U and V into 32-bit chunks and make sure the top # 32 bit chunks are the same (i.e. u3_hi = v3_hi) - (8 cycles) dup.4 u32split dup.2 u32split dup movup.3 assert_eq @@ -99,7 +126,7 @@ end #! Output: [(u3 << 32) >> 48, (v3 << 32) >> 48, v3 >> 32, U, V, ...] #! #! Cycles: 20 -proc.get_common_prefix_32 +proc.extract_index_32_16 # slit teh most significant elements of U and V into 32-bit chunks (4 cycles) dup.4 u32split dup.2 u32split # => [u3_hi, u3_lo, v3_hi, v3_lo, U, V, ...] @@ -385,7 +412,7 @@ end #! Updates a leaf node at depths 16, 32, or 48. #! -#! Input: [d, idx, V, K, R, ...]; +#! Input: [d, idx, V, K, R, ...] #! Output: [V_old, R_new, ...] #! #! Where: @@ -447,8 +474,8 @@ end #! Inserts a new leaf node at depth 16. #! -#! Input: [V, K, R, ...]; -#! Output:[0, 0, 0, 0, R_new, ...] +#! Input: [V, K, R, ...] +#! Output: [0, 0, 0, 0, R_new, ...] #! #! Where: #! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. @@ -493,8 +520,8 @@ end #! Inserts a new leaf node at depth 32. #! -#! Input: [V, K, R, ...]; -#! Output:[0, 0, 0, 0, R_new, ...] +#! Input: [V, K, R, ...] +#! Output: [0, 0, 0, 0, R_new, ...] #! #! Where: #! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. @@ -588,8 +615,8 @@ end #! Inserts a new leaf node at depth 48. #! -#! Input: [V, K, R, ...]; -#! Output:[0, 0, 0, 0, R_new, ...] +#! Input: [V, K, R, ...] +#! Output: [0, 0, 0, 0, R_new, ...] #! #! 
This procedure is nearly identical to the insert_32 procedure above, adjusted for the use of #! constants and idx_hi/idx_lo computation. It may be possible to combine the two at the expense @@ -670,8 +697,8 @@ end #! one of the leaf nodes commits to a key-value pair equal to the leaf node at depth 16, and the #! other leaf node comments to the key-value pair being inserted. #! -#! Input: [V, K, R, ...]; -#! Output:[0, 0, 0, 0, R_new, ...] +#! Input: [idx_lo_e, idx_lo_n, idx_hi, K_e, K, V, R, ...] +#! Output: [0, 0, 0, 0, R_new, ...] #! #! Where: #! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. @@ -690,101 +717,71 @@ end #! - The key in this node has a common prefix with the key to be inserted. This common prefix #! must be greater or equal to 16, but smaller than 32. #! -#! Cycles: 216 +#! Cycles: 188 proc.replace_32.3 - # save k3 into loc[0][0] - (6 cycles) - swapw dup loc_store.0 - # => [K, V, R, ...] - - # compute N = hash([K, V], domain=32) - (6 cycles) - push.0.32.0.0 swapw.2 hperm - # => [X, N, X, R, ...] + # save [idx_hi, idx_lo_n, idx_lo_e, 16] into loc[0] - (4 cycles) + push.16 loc_storew.0 + # => [16, idx_lo_e, idx_lo_n, idx_hi, K_e, K, V, R, ...] - # load the key associated with the existing leaf P from the advice provider and save it in - # loc[1] - (5 cycles) - adv_loadw loc_storew.1 - # => [K_e, N, X, R, ...] + # load V_e from the advice provider (1 cycle) + adv_loadw + # => [V_e, K_e, K, V, R, ...] - # load the value associated with the existing leaf P from the advice provider and save it in - # loc[2] - (10 cycles) - push.0.16.0.0 swapw.2 swapw.3 adv_loadw loc_storew.2 - # => [V_e, K_e, 0, 0, 16, 0, N, R, ...] 
+ # save K_e and V_e into loc[1] and loc[2] respectively (13 cycles) + push.0.16.0.0 swapw.2 loc_storew.1 swapw loc_storew.2 - # compute P = hash([K_e, V_e], domain=16); we will use this later to prove correct execution - # of mtree_set instruction (1 cycle) + # compute P = hash([K_e, V_e], domain=16) - (1 cycles) hperm - # => [X, P, X, N, R, ...] - - # load K_e from loc[1] - (9 cycles) - push.0.32.0.0 swapw loc_loadw.1 - # => [K_e, 0, 0, 32, 0, P, X, N, R, ...] - - # extract from the most significant element of K_e (i.e., ke_3) two most significant 16-bit - # limbs: idx_hi_eidx_lo_e - (6 cycles) - dup exec.get_top_32_bits u32unchecked_divmod.65536 - # => [idx_lo_e, idx_hi_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] - - # load k3 from loc[0][0] and also extract the two most significant 16-bit limbs from it - # (8 cycles) - loc_load.0 exec.get_top_32_bits u32unchecked_divmod.65536 - # => [idx_lo, idx_hi, idx_lo_e, idx_hi_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] + # => [X, P, X, K, V, R, ...] - # make sure the top 16 bits of both keys are the same (4 cycles) - movup.3 dup.2 assert_eq - # => [idx_lo, idx_hi, idx_lo_e, K_e, 0, 0, 32, 0, P, X, N, R, ...] + # prepare the stack for computing M = hash([K_e, V_e], domain=32) - (13 cycles) + loc_loadw.1 push.0.32.0.0 swapw.2 swapw.3 loc_loadw.2 + # => [V_e, K_e, 0, 0, 32, 0, P, K, V, R, ...] - # make sure that the next 16 bits of the keys are not the same; this proves that the keys - # have the same 16-bit prefix, but not the same 32-bit prefix (6 cycles) - movup.2 dup dup.2 neq assert - # => [idx_lo_e, idx_lo, idx_hi, K_e, 0, 0, 32, 0, P, X, N, R, ...] - - # save [idx_hi, idx_lo, idx_lo_e, 0] into loc[0] - (4 cycles) - push.0 loc_storew.0 - # => [0, idx_lo_e, idx_lo, idx_hi, K_e, 0, 0, 32, 0, P, X, N, R, ...] - - # load the value V_e from loc[2] and compute M = hash([K_e, K_e], domain=32) - (4 cycles) - loc_loadw.2 hperm - # => [X, M, X, P, X, N, R, ...] 
- - # load the indexes from loc[0] and drop all but idx_lo_e from the stack (7 cycles) - loc_loadw.0 drop movdn.2 drop drop - # => [idx_lo_e, M, X, P, X, N, R, ...] + # compute M = hash([K_e, V_e], domain=32) - 1 cycle + hperm + # => [X, M, X, P, K, V, R, ...] # push the root of an empty subtree at depth 16 onto the stack (4 cycles) push.EMPTY_16_0.EMPTY_16_1.EMPTY_16_2.EMPTY_16_3 - # => [E16, idx_lo_e, M, X, P, X, N, R, ...] + # => [E, X, M, X, P, K, V, R, ...] + + # prepare the stack for inserting M into E (8 cycles) + swapw loc_loadw.0 movup.2 drop movup.2 drop + # => [16, idx_lo_e, E, M, X, P, K, V, R, ...] - # insert node M into the empty subtree at depth 16; this leaves the new root of the - # subtree T together with the root of an empty subtree at depth 32 - (31 cycles) - movup.4 push.16 mtree_set - # => [E32, T, X, P, X, N, R, ...] + # insert M into an empty subtree rooted at E; this leaves a root of empty subtree at depth 32 + # on the stack - though, we don't need to verify this (29 cycles) + mtree_set + # => [E32, T, X, P, K, V, R, ...] - # drop the E32 root as we don't need it, and arrange the stack for inserting the next - # leaf (12 cycles) - dropw swapw dropw swapw swapw.3 swapw.2 - # => [X, N, T, P, R, ...] + # prepare the stack for computing N = hash([K, V], domain=32) - (15 cycles) + dropw swapw dropw swapdw push.0.32.0.0 swapw.2 + # => [V, K, 0, 0, 32, 0, T, P, R, ...] - # load the indexes from loc[0] and drop all but idx_lo from the stack (7 cycles) - loc_loadw.0 drop drop swap drop - # => [idx_lo, N, T, P, R, ...] + # compute N = hash([K, V], domain=32) - 1 cycle + hperm + # => [X, N, X, T, P, R, ...] 
- # insert node N into the subtree with root T at depth 16; this leaves the new root of the - # subtree P_new on the stack together with the root of an empty subtree at depth 32 - (30 cycles) - push.16 mtree_set + # prepare the stack for inserting N into T (13 cycles) + dropw swapw.2 swapw loc_loadw.0 swap drop movup.2 drop + # => [16, idx_lo_n, T, N, P, R, ...] + + # insert M into an empty subtree rooted at T; this leaves a root of empty subtree at depth 32 + # on the stack - though, we don't need to verify this (29 cycles) + mtree_set # => [E32, P_new, P, R, ...] - # prepare the stack for an mtree_set operation against R; we drop the E32 value as we don't - # need it; the index idx_hi is loaded from memory (10 cycles) - dropw swapw swapw.2 loc_load.0 push.16 + # prepare the stack for inserting P_new into R (10 cycles) + swapw.3 swapw swapw.2 swapw.3 loc_loadw.0 movdn.2 drop drop # => [16, idx_hi, R, P_new, P, ...] - # insert node P_new into the TSMT at depth 16; this puts the new value of TSMT root onto the - # stack together with the old value of the node at depth 16 - (29 cycles) + # insert P_new into the tree rooted at R; this also leaves P_old (the old value of the node) + # on the stack (29 cycles) mtree_set # => [P_old, R_new, P, ...] - # make sure P (which we computed as hash([K_e, V_e], domain=16)) and P_old are the same - # (13 cycles) + # make sure P and P_old are the same (13 cycles) swapw swapw.2 assert_eqw # => [R_new, ...] @@ -797,8 +794,8 @@ end #! such that one of the leaf nodes commits to a key-value pair equal to the leaf node at the #! original depth, and the other leaf node comments to the key-value pair being inserted. #! -#! Input: [E, idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; -#! Output:[0, 0, 0, 0, R_new, ...] +#! Input: [E, idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...] +#! Output: [0, 0, 0, 0, R_new, ...] #! #! Where: #! - R is the initial root of the TSMT, and R_new is the new root of the TSMT. 
@@ -830,114 +827,114 @@ end proc.replace_48.4 # save E into loc[3] and drop it from the stack (7 cycles) loc_storew.3 dropw - # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; + # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...] # save [d, idx_hi, idx_lo_n, idx_lo_e] into loc[0] (3 cycles) loc_storew.0 - # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...]; + # => [idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...] # prepare the stack for computing P = hash([K_e, V_e], domain=d) # load V_e from the advice provider and save it into loc[1] (5 cycles) adv_loadw loc_storew.1 - # => [V_e, K_e, K, V, R, ...]; + # => [V_e, K_e, K, V, R, ...] # (6 cycles) push.0 loc_load.0 push.0.0 - # => [0, 0, d, 0, V_e, K_e, K, V, R, ...]; + # => [0, 0, d, 0, V_e, K_e, K, V, R, ...] # save K_e into loc[2] - (5 cycles) swapw.2 loc_storew.2 swapw - # => [V_e, K_e, 0, 0, d, 0, K, V, R, ...]; + # => [V_e, K_e, 0, 0, d, 0, K, V, R, ...] # compute P = hash([K_e, V_e], domain=d) (1 cycle) hperm - # => [X, P, X, K, V, R, ...]; + # => [X, P, X, K, V, R, ...] # prepare the stack for computing M = hash([K_e, V_e], domain=48) # load K_e and V_e from loc[2] and loc[1] respectively (13 cycles) loc_loadw.2 push.0.48.0.0 swapw.2 swapw.3 loc_loadw.1 - # => [V_e, K_e, 0, 0, 48, 0, P, K, V, R, ...]; + # => [V_e, K_e, 0, 0, 48, 0, P, K, V, R, ...] # compute M = hash([K_e, V_e], domain=48) (1 cycle) hperm - # => [X, M, X, P, K, V, R, ...]; + # => [X, M, X, P, K, V, R, ...] # load the root of empty subtree at depth d from loc[3] (3 cycles) loc_loadw.3 - # => [E, M, X, P, K, V, R, ...]; + # => [E, M, X, P, K, V, R, ...] # prepare the stack for inserting M into E # (5 cycles) swapw swapw.2 loc_loadw.0 - # => [idx_lo_e, idx_lo_n, idx_hi, d, E, M, P, K, V, R, ...]; + # => [idx_lo_e, idx_lo_n, idx_hi, d, E, M, P, K, V, R, ...] # (6 cycles) movdn.3 drop drop neg add.48 - # => [48 - d, idx_lo_e, E, M, P, K, V, R, ...]; + # => [48 - d, idx_lo_e, E, M, P, K, V, R, ...] 
# insert M into an empty subtree rooted at E; this leaves a root of empty subtree at depth 48 # on the stack - though, we don't need to verify this (29 cycles) mtree_set - # => [E48, T, P, K, V, R, ...]; + # => [E48, T, P, K, V, R, ...] # prepare the stack for computing N = hash([K, V], domain=48) # (5 cycles) dropw swapdw - # => [K, V, T, P, R, ...]; + # => [K, V, T, P, R, ...] # (5 cycles) push.0.48.0.0 swapw.2 - # => [V, K, 0, 0, 48, 0, T, P, R, ...]; + # => [V, K, 0, 0, 48, 0, T, P, R, ...] # compute N = hash([K, V], domain=48) - (1 cycles) hperm - # => [X, N, X, T, P, R, ...]; + # => [X, N, X, T, P, R, ...] # prepare the stack for inserting N into T # (6 cycles) dropw swapw.2 swapw - # => [X, T, N, P, R, ...]; + # => [X, T, N, P, R, ...] # (3 cycles) loc_loadw.0 - # => [idx_lo_e, idx_lo_n, idx_hi, d, T, N, P, R, ...]; + # => [idx_lo_e, idx_lo_n, idx_hi, d, T, N, P, R, ...] # (6 cycles) drop movdn.2 drop neg add.48 - # => [48 - d, idx_lo_n, T, N, P, R, ...]; + # => [48 - d, idx_lo_n, T, N, P, R, ...] # insert N into a subtree with root T; this leaves a root of an empty subtree at depth 48 # on the stack - though, we don't need to verify this (29 cycles) mtree_set - # => [E48, P_new, P, R, ...]; + # => [E48, P_new, P, R, ...] # prepare the stack for inserting P_new into R # (4 cycles) swapw.3 swapw swapw.2 swapw.3 - # => [E48, R, P_new, P, ...]; + # => [E48, R, P_new, P, ...] # (3 cycles) loc_loadw.0 - # => [idx_lo_e, idx_lo_n, idx_hi, d, R, P_new, P, ...]; + # => [idx_lo_e, idx_lo_n, idx_hi, d, R, P_new, P, ...] # (3 cycles) drop drop swap - # => [d, idx_hi, R, P_new, P, ...]; + # => [d, idx_hi, R, P_new, P, ...] # insert P_new into the tree rooted at R; this also leaves P_old (the old value of the node) # on the stack (29 cycles) mtree_set - # => [P_old, R_new, P, ...]; + # => [P_old, R_new, P, ...] # make sure P and P_old are the same (13 cycles) swapw swapw.2 assert_eqw - # => [R_new, ...]; + # => [R_new, ...] 
# put the return value onto the stack and return (4 cycles) padw @@ -952,8 +949,8 @@ end #! #! This assumes that the value is not [ZERO; 4]. If it is, the procedure fails. #! -#! Input: [V, K, R, ...]; -#! Output:[V_old, R_new, ...] +#! Input: [V, K, R, ...] +#! Output: [V_old, R_new, ...] #! #! Cycles: #! - Update existing leaf: @@ -965,7 +962,7 @@ end #! - Depth 32: 181 #! - Depth 48: 181 #! - Replace a leaf with a subtree: -#! - Depth 16 -> 32: 243 +#! - Depth 16 -> 32: 242 #! - Depth 16 -> 48: 263 #! - Depth 32 -> 48: 253 export.insert @@ -1043,6 +1040,15 @@ export.insert else # --- replace leaf with subtree ------------------------------------------------------ if.true if.true # --- replace a leaf at depth 16 with two leaves at depth 32 --- + # load K_e from the advice provider (5 cycles) + swapw adv_push.4 + # => [K_e, K, V, R, ...] + + # (20 cycles) + exec.extract_index_16_16 + # => [idx_lo_e, idx_lo, idx_hi, K_e, K, V, R, ...] + + # (188 cycles) exec.replace_32 else # --- replace a leaf at depth 16 with two leaves at depth 48 --- # load K_e from the advice provider (5 cycles) @@ -1050,7 +1056,7 @@ export.insert # => [K_e, K, V, R, ...] # (30 cycles) - exec.get_common_prefix_16 + exec.extract_index_16_32 # => [idx_lo_e, idx_lo, idx_hi, K_e, K, V, R, ...] # (2 cycles) @@ -1063,7 +1069,6 @@ export.insert # (195 cycles) exec.replace_48 - # => [0, 0, 0, 0, R_new, ...] end else if.true # --- replace a leaf at depth 32 with two leaves at depth 48 --- @@ -1072,7 +1077,7 @@ export.insert # => [K_e, K, V, R, ...] # (20 cycles) - exec.get_common_prefix_32 + exec.extract_index_32_16 # => [idx_lo_e, idx_lo, idx_hi, K_e, K, V, R, ...] # (2 cycles) @@ -1085,7 +1090,6 @@ export.insert # (195 cycles) exec.replace_48 - # => [0, 0, 0, 0, R_new, ...] 
else # --- replace a leaf at depth 16, 32, or 48 with two leaves at depth 64 --- # depth 64 - currently not implemented push.0 assert diff --git a/stdlib/docs/collections/smt.md b/stdlib/docs/collections/smt.md index 9d31187bc6..2ff3771647 100644 --- a/stdlib/docs/collections/smt.md +++ b/stdlib/docs/collections/smt.md @@ -3,4 +3,4 @@ | Procedure | Description | | ----------- | ------------- | | get | Returns the value stored under the specified key in a Sparse Merkle Tree with the specified root.

If the value for a given key has not been set, the returned `V` will consist of all zeroes.

Input: [K, R, ...]

Output: [V, R, ...]

Depth 16: 91 cycles

Depth 32: 87 cycles

Depth 48: 94 cycles

Depth 64: unimplemented | -| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...];

Output:[V_old, R_new, ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 16 -> 32: 243

- Depth 16 -> 48: 263

- Depth 32 -> 48: 253 | +| insert | Inserts the specified value into a Sparse Merkle Tree with the specified root under the

specified key.

The value previously stored in the SMT under this key is left on the stack together with

the updated tree root.

This assumes that the value is not [ZERO; 4]. If it is, the procedure fails.

Input: [V, K, R, ...]

Output: [V_old, R_new, ...]

Cycles:

- Update existing leaf:

- Depth 16: 129

- Depth 32: 126

- Depth 48: 131

- Insert new leaf:

- Depth 16: 100

- Depth 32: 181

- Depth 48: 181

- Replace a leaf with a subtree:

- Depth 16 -> 32: 242

- Depth 16 -> 48: 263

- Depth 32 -> 48: 253 | From 7d707625f40f4422d05a1b491010fc5ce4c28d4e Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Mon, 14 Aug 2023 18:59:38 -0700 Subject: [PATCH 18/18] fix: implement TSMT leaf insertion into advice map --- CHANGELOG.md | 1 + assembly/src/ast/nodes/advice.rs | 6 + assembly/src/ast/parsers/adv_ops.rs | 4 + core/src/operations/decorators/advice.rs | 16 ++ docs/src/user_docs/assembly/io_operations.md | 16 +- processor/src/advice/mod.rs | 1 + processor/src/advice/source.rs | 4 +- processor/src/decorators/adv_map_injectors.rs | 50 +++- processor/src/decorators/mod.rs | 1 + stdlib/asm/collections/smt.masm | 36 ++- stdlib/tests/collections/smt.rs | 220 +++++++++++------- 11 files changed, 251 insertions(+), 104 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7777072cde..3467a431ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Added support for nested modules (#992). - Added support for the arithmetic expressions in constant values (#1026). - Added support for module aliases (#1037). +- Added `adv.insert_hperm` decorator (#1042). #### VM Internals - Simplified range checker and removed 1 main and 1 auxiliary trace column (#949). 
diff --git a/assembly/src/ast/nodes/advice.rs b/assembly/src/ast/nodes/advice.rs index 00d054da9e..9ea83d4aad 100644 --- a/assembly/src/ast/nodes/advice.rs +++ b/assembly/src/ast/nodes/advice.rs @@ -27,6 +27,7 @@ pub enum AdviceInjectorNode { InsertMem, InsertHdword, InsertHdwordImm { domain: u8 }, + InsertHperm, } impl From<&AdviceInjectorNode> for AdviceInjector { @@ -59,6 +60,7 @@ impl From<&AdviceInjectorNode> for AdviceInjector { InsertHdwordImm { domain } => Self::HdwordToMap { domain: Felt::from(*domain), }, + InsertHperm => Self::HpermToMap, } } } @@ -79,6 +81,7 @@ impl fmt::Display for AdviceInjectorNode { InsertMem => write!(f, "insert_mem"), InsertHdword => write!(f, "insert_hdword"), InsertHdwordImm { domain } => write!(f, "insert_hdword.{domain}"), + InsertHperm => writeln!(f, "insert_hperm"), } } } @@ -98,6 +101,7 @@ const PUSH_MTNODE: u8 = 8; const INSERT_MEM: u8 = 9; const INSERT_HDWORD: u8 = 10; const INSERT_HDWORD_IMM: u8 = 11; +const INSERT_HPERM: u8 = 12; impl Serializable for AdviceInjectorNode { fn write_into(&self, target: &mut W) { @@ -124,6 +128,7 @@ impl Serializable for AdviceInjectorNode { target.write_u8(INSERT_HDWORD_IMM); target.write_u8(*domain); } + InsertHperm => target.write_u8(INSERT_HPERM), } } } @@ -158,6 +163,7 @@ impl Deserializable for AdviceInjectorNode { let domain = source.read_u8()?; Ok(AdviceInjectorNode::InsertHdwordImm { domain }) } + INSERT_HPERM => Ok(AdviceInjectorNode::InsertHperm), val => Err(DeserializationError::InvalidValue(val.to_string())), } } diff --git a/assembly/src/ast/parsers/adv_ops.rs b/assembly/src/ast/parsers/adv_ops.rs index 1393796a05..c4688a9ac6 100644 --- a/assembly/src/ast/parsers/adv_ops.rs +++ b/assembly/src/ast/parsers/adv_ops.rs @@ -81,6 +81,10 @@ pub fn parse_adv_inject(op: &Token) -> Result { } _ => return Err(ParsingError::extra_param(op)), }, + "insert_hperm" => match op.num_parts() { + 2 => AdvInject(InsertHperm), + _ => return Err(ParsingError::extra_param(op)), + }, _ => return 
Err(ParsingError::invalid_op(op)), }; diff --git a/core/src/operations/decorators/advice.rs b/core/src/operations/decorators/advice.rs index b6c7229d84..854a4c7469 100644 --- a/core/src/operations/decorators/advice.rs +++ b/core/src/operations/decorators/advice.rs @@ -214,6 +214,21 @@ pub enum AdviceInjector { /// Where KEY is computed as hash(A || B, domain), where domain is provided via the immediate /// value. HdwordToMap { domain: Felt }, + + /// Reads three words from the operand stack and inserts the top two words into the advice map + /// under the key defined by applying an RPO permutation to all three words. + /// + /// Inputs: + /// Operand stack: [B, A, C, ...] + /// Advice map: {...} + /// + /// Outputs: + /// Operand stack: [B, A, C, ...] + /// Advice map: {KEY: [a0, a1, a2, a3, b0, b1, b2, b3]} + /// + /// Where KEY is computed by extracting the digest elements from hperm([C, A, B]). For example, + /// if C is [0, d, 0, 0], KEY will be set as hash(A || B, d). + HpermToMap, } impl fmt::Display for AdviceInjector { @@ -238,6 +253,7 @@ impl fmt::Display for AdviceInjector { Self::SmtInsert => write!(f, "smt_insert"), Self::MemToMap => write!(f, "mem_to_map"), Self::HdwordToMap { domain } => write!(f, "hdword_to_map.{domain}"), + Self::HpermToMap => write!(f, "hperm_to_map"), } } } diff --git a/docs/src/user_docs/assembly/io_operations.md b/docs/src/user_docs/assembly/io_operations.md index a754fd79f0..053c34a509 100644 --- a/docs/src/user_docs/assembly/io_operations.md +++ b/docs/src/user_docs/assembly/io_operations.md @@ -47,14 +47,16 @@ Advice injectors fall into two categories: (1) injectors which push new data ont | Instruction | Stack_input | Stack_output | Notes | | -------------------------------------------- | -------------------------- | -------------------------- | 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| adv.push_mapval
adv.push_mapval.*s* | [K, ... ] | [K, ... ] | Pushes a list of field elements onto the advice stack. The list is looked up in the advice map using word $K$ as the key. If offset $s$ is provided, the key is taken starting from item $s$ on the stack. | +| adv.push_mapval
adv.push_mapval.*s* | [K, ... ] | [K, ... ] | Pushes a list of field elements onto the advice stack. The list is looked up in the advice map using word $K$ as the key. If offset $s$ is provided, the key is taken starting from item $s$ on the stack. | | adv.push_mapvaln
adv.push_mapvaln.*s* | [K, ... ] | [K, ... ] | Pushes a list of field elements together with the number of elements onto the advice stack. The list is looked up in the advice map using word $K$ as the key. If offset $s$ is provided, the key is taken starting from item $s$ on the stack. | -| adv.push_mtnode | [d, i, R, ... ] | [d, i, R, ... ] | Pushes a node of a Merkle tree with root $R$ at depth $d$ and index $i$ from Merkle store onto the advice stack. | -| adv.push_u64div | [b1, b0, a1, a0, ...] | [b1, b0, a1, a0, ...] | Pushes the result of `u64` division $a / b$ onto the advice stack. Both $a$ and $b$ are represented using 32-bit limbs. The result consists of both the quotient and the remainder. | -| adv.push_ext2intt | [osize, isize, iptr, ... ] | [osize, isize, iptr, ... ] | Given evaluations of a polynomial over some specified domain, interpolates the evaluations into a polynomial in coefficient form and pushes the result into the advice stack. | -| adv.smt_get | [K, R, ... ] | [K, R, ... ] | Pushes values onto the advice stack which are required for successful retrieval of a value under the key $K$ from a Sparse Merkle Tree with root $R$. | -| adv.insert_mem | [K, a, b, ... ] | [K, a, b, ... ] | Reads words $data \leftarrow mem[a] .. mem[b]$ from memory, and save the data into $advice\_map[K] \leftarrow data$. | -| adv.insert_hdword
adv.insert_hdword.*d* | [B, A, ... ] | [B, A, ... ] | Reads top two words from the stack, computes a key as $K \leftarrow hash(A || b, d)$, and saves the data into $advice\_map[K] \leftarrow [A, B]$. $d$ is an optional domain value which can be between $0$ and $255$, default value $0$. | +| adv.push_mtnode | [d, i, R, ... ] | [d, i, R, ... ] | Pushes a node of a Merkle tree with root $R$ at depth $d$ and index $i$ from Merkle store onto the advice stack. | +| adv.push_u64div | [b1, b0, a1, a0, ...] | [b1, b0, a1, a0, ...] | Pushes the result of `u64` division $a / b$ onto the advice stack. Both $a$ and $b$ are represented using 32-bit limbs. The result consists of both the quotient and the remainder. | +| adv.push_ext2intt | [osize, isize, iptr, ... ] | [osize, isize, iptr, ... ] | Given evaluations of a polynomial over some specified domain, interpolates the evaluations into a polynomial in coefficient form and pushes the result into the advice stack. | +| adv.smt_get | [K, R, ... ] | [K, R, ... ] | Pushes values onto the advice stack which are required for successful retrieval of a value under the key $K$ from a Sparse Merkle Tree with root $R$. | +| adv.smt_insert | [V, K, R, ...] | [V, K, R, ...] | Pushes values onto the advice stack which are required for successful insertion of a key-value pair $(K, V)$ into a Sparse Merkle Tree with root $R$. | +| adv.insert_mem | [K, a, b, ... ] | [K, a, b, ... ] | Reads words $data \leftarrow mem[a] .. mem[b]$ from memory, and save the data into $advice\_map[K] \leftarrow data$. | +| adv.insert_hdword
adv.insert_hdword.*d* | [B, A, ... ] | [B, A, ... ] | Reads top two words from the stack, computes a key as $K \leftarrow hash(A || B, d)$, and saves the data into $advice\_map[K] \leftarrow [A, B]$. $d$ is an optional domain value which can be between $0$ and $255$, default value $0$. | +| adv.insert_hperm | [B, A, C, ...] | [B, A, C, ...] | Reads top three words from the stack, computes a key as $K \leftarrow permute(C, A, B).digest$, and saves data into $advice\_map[K] \leftarrow [A, B]$. | ### Random access memory diff --git a/processor/src/advice/mod.rs b/processor/src/advice/mod.rs index a141d6019d..8382a2075b 100644 --- a/processor/src/advice/mod.rs +++ b/processor/src/advice/mod.rs @@ -103,6 +103,7 @@ pub trait AdviceProvider { // ADVICE MAP // -------------------------------------------------------------------------------------------- + /// Returns a reference to the value(s) associated with the specified key in the advice map. fn get_mapped_values(&self, key: &[u8; 32]) -> Option<&[Felt]>; diff --git a/processor/src/advice/source.rs b/processor/src/advice/source.rs index d87f7ed436..e52df4509b 100644 --- a/processor/src/advice/source.rs +++ b/processor/src/advice/source.rs @@ -9,13 +9,13 @@ pub enum AdviceSource { /// Puts a single value onto the advice stack. Value(Felt), - /// Puts a word (4 elements) ont the the stack. + /// Puts a word (4 elements) onto the stack. Word(Word), /// Fetches a list of elements under the specified key from the advice map and pushes them onto /// the advice stack. /// - /// If `include_len` is set to true, this also pushes the number of elements ont the advice + /// If `include_len` is set to true, this also pushes the number of elements onto the advice /// stack. 
/// /// Note: this operation doesn't consume the map element so it can be called multiple times diff --git a/processor/src/decorators/adv_map_injectors.rs b/processor/src/decorators/adv_map_injectors.rs index 34eb44a643..aa65bbadbc 100644 --- a/processor/src/decorators/adv_map_injectors.rs +++ b/processor/src/decorators/adv_map_injectors.rs @@ -1,5 +1,9 @@ use super::{AdviceProvider, ExecutionError, Process}; -use vm_core::{crypto::hash::Rpo256, utils::collections::Vec, Felt, StarkField, WORD_SIZE, ZERO}; +use vm_core::{ + crypto::hash::{Rpo256, RpoDigest}, + utils::collections::Vec, + Felt, StarkField, WORD_SIZE, ZERO, +}; // ADVICE INJECTORS // ================================================================================================ @@ -72,6 +76,50 @@ where self.advice_provider.insert_into_map(key.into(), values) } + /// Reads three words from the operand stack and inserts the top two words into the advice map + /// under the key defined by applying an RPO permutation to all three words. + /// + /// Inputs: + /// Operand stack: [B, A, C, ...] + /// Advice map: {...} + /// + /// Outputs: + /// Operand stack: [B, A, C, ...] + /// Advice map: {KEY: [a0, a1, a2, a3, b0, b1, b2, b3]} + /// + /// Where KEY is computed by extracting the digest elements from hperm([C, A, B]). For example, + /// if C is [0, d, 0, 0], KEY will be set as hash(A || B, d). 
+ pub(super) fn insert_hperm_into_adv_map(&mut self) -> Result<(), ExecutionError> { + // read the state from the stack + let mut state = [ + self.stack.get(11), + self.stack.get(10), + self.stack.get(9), + self.stack.get(8), + self.stack.get(7), + self.stack.get(6), + self.stack.get(5), + self.stack.get(4), + self.stack.get(3), + self.stack.get(2), + self.stack.get(1), + self.stack.get(0), + ]; + + // get the values to be inserted into the advice map from the state + let values = state[Rpo256::RATE_RANGE].to_vec(); + + // apply the permutation to the state and extract the key from it + Rpo256::apply_permutation(&mut state); + let key = RpoDigest::new( + state[Rpo256::DIGEST_RANGE] + .try_into() + .expect("failed to extract digest from state"), + ); + + self.advice_provider.insert_into_map(key.into(), values) + } + // HELPER METHODS // -------------------------------------------------------------------------------------------- diff --git a/processor/src/decorators/mod.rs b/processor/src/decorators/mod.rs index 811f53b712..a4e601e438 100644 --- a/processor/src/decorators/mod.rs +++ b/processor/src/decorators/mod.rs @@ -48,6 +48,7 @@ where AdviceInjector::SmtInsert => self.push_smtinsert_inputs(), AdviceInjector::MemToMap => self.insert_mem_values_into_adv_map(), AdviceInjector::HdwordToMap { domain } => self.insert_hdword_into_adv_map(*domain), + AdviceInjector::HpermToMap => self.insert_hperm_into_adv_map(), } } diff --git a/stdlib/asm/collections/smt.masm b/stdlib/asm/collections/smt.masm index 5a00de3b6f..a831a09044 100644 --- a/stdlib/asm/collections/smt.masm +++ b/stdlib/asm/collections/smt.masm @@ -127,7 +127,7 @@ end #! #! Cycles: 20 proc.extract_index_32_16 - # slit teh most significant elements of U and V into 32-bit chunks (4 cycles) + # split the most significant elements of U and V into 32-bit chunks (4 cycles) dup.4 u32split dup.2 u32split # => [u3_hi, u3_lo, v3_hi, v3_lo, U, V, ...] 
@@ -435,6 +435,9 @@ proc.update_16_32_48.2 movdn.3 movup.2 drop push.0 swapw.2 loc_storew.1 swapw # => [V, K, 0, 0, d, 0, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute the hash of the node N = hash([K, V], domain=d) - (1 cycle) hperm # => [X, N, X, R, ...] @@ -490,10 +493,13 @@ proc.insert_16 swapw dup exec.get_top_16_bits # => [idx, K, V, R, ...] - # prepare the stack for computing leaf node value (6 cycles) + # prepare the stack for computing N = hash([K, V], domain=16) (6 cycles) movdn.8 push.0.16.0.0 swapw.2 # => [V, K, 0, 0, 16, 0, idx, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute leaf node value as N = hash([K, V], domain=16) (10 cycles) hperm dropw swapw dropw # => [N, idx, R, ...] @@ -554,6 +560,9 @@ proc.insert_32.2 push.0.32.0.0 swapw.2 # => [V, K, 0, 0, 32, 0, P, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute N = hash([K, V], domain=32) (1 cycle) hperm # => [X, N, X, P, R, ...] @@ -634,6 +643,9 @@ proc.insert_48.2 push.0.48.0.0 swapw.2 # => [V, K, 0, 0, 48, 0, P, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute N = hash([K, V], domain=48) (1 cycle) hperm # => [X, N, X, P, R, ...] @@ -695,7 +707,7 @@ end #! Replaces a leaf node at depth 16 with a subtree containing two leaf nodes at depth 32 such that #! one of the leaf nodes commits to a key-value pair equal to the leaf node at depth 16, and the -#! other leaf node comments to the key-value pair being inserted. +#! other leaf node commits to the key-value pair being inserted. #! #! Input: [idx_lo_e, idx_lo_n, idx_hi, K_e, K, V, R, ...] #! Output: [0, 0, 0, 0, R_new, ...] @@ -738,6 +750,9 @@ proc.replace_32.3 loc_loadw.1 push.0.32.0.0 swapw.2 swapw.3 loc_loadw.2 # => [V_e, K_e, 0, 0, 32, 0, P, K, V, R, ...] 
+ # insert M |-> [K_e, V_e] into the advice map (0 cycles) + adv.insert_hperm + # compute M = hash([K_e, V_e], domain=32) - 1 cycle hperm # => [X, M, X, P, K, V, R, ...] @@ -759,6 +774,9 @@ proc.replace_32.3 dropw swapw dropw swapdw push.0.32.0.0 swapw.2 # => [V, K, 0, 0, 32, 0, T, P, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute N = hash([K, V], domain=32) - 1 cycle hperm # => [X, N, X, T, P, R, ...] @@ -767,7 +785,7 @@ proc.replace_32.3 dropw swapw.2 swapw loc_loadw.0 swap drop movup.2 drop # => [16, idx_lo_n, T, N, P, R, ...] - # insert M into an empty subtree rooted at T; this leaves a root of empty subtree at depth 32 + # insert N into an empty subtree rooted at T; this leaves a root of empty subtree at depth 32 # on the stack - though, we don't need to verify this (29 cycles) mtree_set # => [E32, P_new, P, R, ...] @@ -792,7 +810,7 @@ end #! Replaces a leaf node at depth 16 or 32 with a subtree containing two leaf nodes at depth 48 #! such that one of the leaf nodes commits to a key-value pair equal to the leaf node at the -#! original depth, and the other leaf node comments to the key-value pair being inserted. +#! original depth, and the other leaf node commits to the key-value pair being inserted. #! #! Input: [E, idx_lo_e, idx_lo_n, idx_hi, d, K_e, K, V, R, ...] #! Output: [0, 0, 0, 0, R_new, ...] @@ -808,7 +826,7 @@ end #! #! This procedure consists of three high-level steps: #! - First, insert M = hash([K_e, V_e], domain=48) into an empty subtree at depth 48 - d, where -#! K_e and V_e are the key-value pair for the existing leaf node at depth d. This outputs the +#! K_e and V_e is the key-value pair for the existing leaf node at depth d. This outputs the #! new root of the subtree T. #! - Then, insert N = hash([K, V], domain=48) into a subtree with root T. This outputs the new #! root of the subtree P_new. 
@@ -857,6 +875,9 @@ proc.replace_48.4 loc_loadw.2 push.0.48.0.0 swapw.2 swapw.3 loc_loadw.1 # => [V_e, K_e, 0, 0, 48, 0, P, K, V, R, ...] + # insert M |-> [K_e, V_e] into the advice map (0 cycles) + adv.insert_hperm + # compute M = hash([K_e, V_e], domain=48) (1 cycle) hperm # => [X, M, X, P, K, V, R, ...] @@ -890,6 +911,9 @@ proc.replace_48.4 push.0.48.0.0 swapw.2 # => [V, K, 0, 0, 48, 0, T, P, R, ...] + # insert N |-> [K, V] into the advice map (0 cycles) + adv.insert_hperm + # compute N = hash([K, V], domain=48) - (1 cycles) hperm # => [X, N, X, T, P, R, ...] diff --git a/stdlib/tests/collections/smt.rs b/stdlib/tests/collections/smt.rs index e47ceb860c..d5b10993a6 100644 --- a/stdlib/tests/collections/smt.rs +++ b/stdlib/tests/collections/smt.rs @@ -1,9 +1,11 @@ use crate::build_test; use test_utils::{ - crypto::{MerkleStore, RpoDigest, TieredSmt}, - Felt, StarkField, Word, ONE, ZERO, + crypto::{MerkleStore, Rpo256, RpoDigest, TieredSmt}, + stack_to_ints, stack_top_to_ints, Felt, StarkField, Word, ONE, ZERO, }; +type AdvMapEntry = ([u8; 32], Vec<Felt>); + // CONSTANTS // ================================================================================================ @@ -21,22 +23,22 @@ fn tsmt_get_16() { let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); // make sure we get an empty value for this key - assert_smt_get_opens_correctly(&smt, key_a, EMPTY_VALUE); + assert_get(&smt, key_a, EMPTY_VALUE); // insert a value under this key and make sure we get it back when queried let val_a = [ONE, ONE, ONE, ONE]; smt.insert(key_a, val_a); - assert_smt_get_opens_correctly(&smt, key_a, val_a); + assert_get(&smt, key_a, val_a); // make sure that another key still returns empty value let raw_b = 0b_01111101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); - assert_smt_get_opens_correctly(&smt, key_b, EMPTY_VALUE); + assert_get(&smt, key_b, EMPTY_VALUE); // make sure that another key
with the same 16-bit prefix returns an empty value let raw_c = 0b_01010101_01101100_11111111_11111111_10010110_10010011_11100000_00000000_u64; let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); - assert_smt_get_opens_correctly(&smt, key_c, EMPTY_VALUE); + assert_get(&smt, key_c, EMPTY_VALUE); } #[test] @@ -55,23 +57,23 @@ fn tsmt_get_32() { smt.insert(key_b, val_b); // make sure the values for these keys are retrieved correctly - assert_smt_get_opens_correctly(&smt, key_a, val_a); - assert_smt_get_opens_correctly(&smt, key_b, val_b); + assert_get(&smt, key_a, val_a); + assert_get(&smt, key_b, val_b); // make sure another key with the same 16-bit prefix returns an empty value let raw_c = 0b_01010101_01010101_11100111_11111111_10010110_10010011_11100000_00000000_u64; let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); - assert_smt_get_opens_correctly(&smt, key_c, EMPTY_VALUE); + assert_get(&smt, key_c, EMPTY_VALUE); // make sure keys with the same 32-bit prefixes return empty value let raw_d = 0b_01010101_01010101_00011111_11111111_11111110_10010011_11100000_00000000_u64; let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]); - assert_smt_get_opens_correctly(&smt, key_d, EMPTY_VALUE); + assert_get(&smt, key_d, EMPTY_VALUE); // make sure keys with the same 32-bit prefixes return empty value let raw_e = 0b_01010101_01010101_11100000_11111111_10011111_10010011_11100000_00000000_u64; let key_e = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_e)]); - assert_smt_get_opens_correctly(&smt, key_e, EMPTY_VALUE); + assert_get(&smt, key_e, EMPTY_VALUE); } #[test] @@ -90,26 +92,77 @@ fn tsmt_get_48() { smt.insert(key_b, val_b); // make sure the values for these keys are retrieved correctly - assert_smt_get_opens_correctly(&smt, key_a, val_a); - assert_smt_get_opens_correctly(&smt, key_b, val_b); + assert_get(&smt, key_a, val_a); + assert_get(&smt, key_b, val_b); // make sure another key with the same 32-bit prefix returns an empty value let raw_c 
= 0b_01010101_01010101_00011111_11111111_00000000_10010011_11100000_00000000_u64; let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); - assert_smt_get_opens_correctly(&smt, key_c, EMPTY_VALUE); + assert_get(&smt, key_c, EMPTY_VALUE); // make sure keys with the same 48-bit prefixes return empty value let raw_d = 0b_01010101_01010101_00011111_11111111_10010110_10010011_00000111_00000000_u64; let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]); - assert_smt_get_opens_correctly(&smt, key_d, EMPTY_VALUE); + assert_get(&smt, key_d, EMPTY_VALUE); // make sure keys with the same 48-bit prefixes return empty value let raw_e = 0b_01010101_01010101_00011111_11111111_11111111_10010011_000001011_00000000_u64; let key_e = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_e)]); - assert_smt_get_opens_correctly(&smt, key_e, EMPTY_VALUE); + assert_get(&smt, key_e, EMPTY_VALUE); +} + +/// Asserts key/value opens to root for the provided Tiered Sparse Merkle tree. +fn assert_get(smt: &TieredSmt, key: RpoDigest, value: Word) { + let root = smt.root(); + let source = r#" + use.std::collections::smt + + begin + exec.smt::get + end + "#; + let initial_stack = [ + root[0].as_int(), + root[1].as_int(), + root[2].as_int(), + root[3].as_int(), + key[0].as_int(), + key[1].as_int(), + key[2].as_int(), + key[3].as_int(), + ]; + let expected_output = [ + value[3].as_int(), + value[2].as_int(), + value[1].as_int(), + value[0].as_int(), + root[3].as_int(), + root[2].as_int(), + root[1].as_int(), + root[0].as_int(), + ]; + + let (store, advice_map) = build_advice_inputs(smt); + let advice_stack = []; + build_test!(source, &initial_stack, &advice_stack, store, advice_map.into_iter()) + .expect_stack(&expected_output); +} + +fn build_advice_inputs(smt: &TieredSmt) -> (MerkleStore, Vec<([u8; 32], Vec<Felt>)>) { + let store = MerkleStore::from(smt); + let advice_map = smt + .upper_leaves() + .map(|(node, key, value)| { + let mut elements = key.as_elements().to_vec(); +
elements.extend(&value); + (node.as_bytes(), elements) + }) + .collect::<Vec<_>>(); + + (store, advice_map) } -// INSERTS +// INSERTION TESTS // ================================================================================================ #[test] @@ -121,15 +174,17 @@ fn tsmt_insert_16() { let val_a1 = [ONE, ZERO, ZERO, ZERO]; let val_a2 = [ONE, ONE, ZERO, ZERO]; - // insert a value under key_a into an empty tree + // insert a value under key_a into an empty tree; this inserts one entry into the advice map let init_smt = smt.clone(); smt.insert(key_a.into(), val_a1); - assert_insert(&init_smt, key_a, EMPTY_VALUE, val_a1, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a1, 16)]; + assert_insert(&init_smt, key_a, EMPTY_VALUE, val_a1, smt.root().into(), &new_map_entries); - // update a value under key_a + // update a value under key_a; this inserts one entry into the advice map let init_smt = smt.clone(); smt.insert(key_a.into(), val_a2); - assert_insert(&init_smt, key_a, val_a1, val_a2, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a2, 16)]; + assert_insert(&init_smt, key_a, val_a1, val_a2, smt.root().into(), &new_map_entries); } #[test] @@ -147,25 +202,30 @@ fn tsmt_insert_32() { let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); let val_b = [ONE, ONE, ZERO, ZERO]; - // this tests a complex insertion when a leaf node moves from depth 16 to depth 32 + // this tests a complex insertion when a leaf node moves from depth 16 to depth 32; this + // moves the original node to depth 32, and thus two new entries are added to the advice map let init_smt = smt.clone(); smt.insert(key_b.into(), val_b); - assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a, 32), build_node_entry(key_b, val_b, 32)]; + assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into(), &new_map_entries); - // update a value under
key_a; this adds one new entry to the advice map let init_smt = smt.clone(); let val_a2 = [ONE, ZERO, ZERO, ONE]; smt.insert(key_a.into(), val_a2); - assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a2, 32)]; + assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into(), &new_map_entries); - // insert a value under key_c which has the same 16-bit prefix as A and B + // insert a value under key_c which has the same 16-bit prefix as A and B; this inserts a new + // node at depth 32, and thus adds one entry to the advice map let raw_c = 0b00000000_00000000_00111111_11111111_11111111_11111111_11111111_11111111_u64; let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); let val_c = [ONE, ONE, ONE, ZERO]; let init_smt = smt.clone(); smt.insert(key_c.into(), val_c); - assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); + let new_map_entries = [build_node_entry(key_c, val_c, 32)]; + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into(), &new_map_entries); } #[test] @@ -183,25 +243,30 @@ fn tsmt_insert_48() { let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); let val_b = [ONE, ONE, ZERO, ZERO]; - // this tests a complex insertion when a leaf moves from depth 16 to depth 48 + // this tests a complex insertion when a leaf moves from depth 16 to depth 48; this moves + // node at depth 16 to depth 48 and inserts a new node at depth 48 let init_smt = smt.clone(); smt.insert(key_b.into(), val_b); - assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a, 48), build_node_entry(key_b, val_b, 48)]; + assert_insert(&init_smt, key_b, EMPTY_VALUE, val_b, smt.root().into(), &new_map_entries); - // update a value under key_a + // update a value under key_a; this inserts one entry into the advice map let init_smt = smt.clone(); let val_a2 = [ONE, ZERO, ZERO, ONE]; 
smt.insert(key_a.into(), val_a2); - assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a2, 48)]; + assert_insert(&init_smt, key_a, val_a, val_a2, smt.root().into(), &new_map_entries); - // insert a value under key_c which has the same 32-bit prefix as A and B + // insert a value under key_c which has the same 32-bit prefix as A and B; this inserts + // one entry into the advice map let raw_c = 0b00000000_00000000_11111111_11111111_00111111_11111111_11111111_11111111_u64; let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); let val_c = [ONE, ONE, ONE, ZERO]; let init_smt = smt.clone(); smt.insert(key_c.into(), val_c); - assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); + let new_map_entries = [build_node_entry(key_c, val_c, 48)]; + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into(), &new_map_entries); } #[test] @@ -225,10 +290,12 @@ fn tsmt_insert_48_from_32() { let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); let val_c = [ONE, ONE, ONE, ZERO]; - // this tests a complex insertion when a leaf moves from depth 32 to depth 48 + // this tests a complex insertion when a leaf moves from depth 32 to depth 48; two new + // entries are added to the advice map let init_smt = smt.clone(); smt.insert(key_c.into(), val_c); - assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into()); + let new_map_entries = [build_node_entry(key_a, val_a, 48), build_node_entry(key_c, val_c, 48)]; + assert_insert(&init_smt, key_c, EMPTY_VALUE, val_c, smt.root().into(), &new_map_entries); } fn assert_insert( @@ -237,6 +304,7 @@ fn assert_insert( old_value: Word, new_value: Word, new_root: RpoDigest, + new_map_entries: &[AdvMapEntry], ) { let old_root = init_smt.root(); let source = r#" @@ -260,7 +328,7 @@ fn assert_insert( new_value[2].as_int(), new_value[3].as_int(), ]; - let expected_output = [ + let expected_output = stack_top_to_ints(&[ 
old_value[3].as_int(), old_value[2].as_int(), old_value[1].as_int(), @@ -269,61 +337,37 @@ fn assert_insert( new_root[2].as_int(), new_root[1].as_int(), new_root[0].as_int(), - ]; + ]); let (store, adv_map) = build_advice_inputs(init_smt); - build_test!(source, &initial_stack, &[], store, adv_map).expect_stack(&expected_output); + let process = build_test!(source, &initial_stack, &[], store, adv_map.clone()) + .execute_process() + .unwrap(); + + // check the returned values + let stack = stack_to_ints(&process.stack.trace_state()); + assert_eq!(stack, expected_output); + + // remove the initial key-value pairs from the advice map + let mut new_adv_map = process.advice_provider.map().clone(); + for (key, value) in adv_map.iter() { + let init_value = new_adv_map.remove(key).unwrap(); + assert_eq!(value, &init_value); + } + + // make sure the remaining values in the advice map are the same as expected new entries + assert_eq!(new_adv_map.len(), new_map_entries.len()); + for (key, val) in new_map_entries { + let old_val = new_adv_map.get(key).unwrap(); + assert_eq!(old_val, val); + } } -// TEST HELPERS +// HELPER FUNCTIONS // ================================================================================================ -/// Asserts key/value opens to root for the provided Tiered Sparse Merkle tree. 
-fn assert_smt_get_opens_correctly(smt: &TieredSmt, key: RpoDigest, value: Word) { - let root = smt.root(); - let source = r#" - use.std::collections::smt - - begin - exec.smt::get - end - "#; - let initial_stack = [ - root[0].as_int(), - root[1].as_int(), - root[2].as_int(), - root[3].as_int(), - key[0].as_int(), - key[1].as_int(), - key[2].as_int(), - key[3].as_int(), - ]; - let expected_output = [ - value[3].as_int(), - value[2].as_int(), - value[1].as_int(), - value[0].as_int(), - root[3].as_int(), - root[2].as_int(), - root[1].as_int(), - root[0].as_int(), - ]; - - let (store, advice_map) = build_advice_inputs(smt); - let advice_stack = []; - build_test!(source, &initial_stack, &advice_stack, store, advice_map.into_iter()) - .expect_stack(&expected_output); -} - -fn build_advice_inputs(smt: &TieredSmt) -> (MerkleStore, Vec<([u8; 32], Vec<Felt>)>) { - let store = MerkleStore::from(smt); - let advice_map = smt - .upper_leaves() - .map(|(node, key, value)| { - let mut elements = key.as_elements().to_vec(); - elements.extend(&value); - (node.as_bytes(), elements) - }) - .collect::<Vec<_>>(); - - (store, advice_map) +fn build_node_entry(key: RpoDigest, value: Word, depth: u8) -> AdvMapEntry { + let digest = Rpo256::merge_in_domain(&[key.into(), value.into()], depth.into()); + let mut elements = key.to_vec(); + elements.extend_from_slice(&value); + (digest.into(), elements) }