From fd973e9ee97444f8589efdee3698e8ba4065ad7f Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Tue, 11 Feb 2025 15:55:54 +0800 Subject: [PATCH] Bus multi interaction (arbitrary number of columns) (#2457) CI fails from my attempt to break this up to several PR, I think it's probably because materializing the helper column isn't compatible with our permutation/lookup -> bus witgen infrastructure. #2458, which is #2457 plus not materializing the helper column in some cases, is passing all CI. Is it possible to merge this with CI failures which will be immediately fixed by #2458? Ready for a final review. Broken up from https://github.com/powdr-labs/powdr/pull/2449 so that this works end to end for arbitrary number of columns (both even and odd). Instead of using and Option for helper columns, which currently bugs out, this version uses an array and forces the case of odd number of bus interactions to materialize a helper_last = multiplicity_last / payloads_last. This has the advantage of passing tests end to end, so we can debug in another PR. One curious case is that CI test for block_to_block_with_bus_composite() failed at the prover stage on the helper column witness constraint before if I use the bus multi version for it. Maybe some updates on witness generation for helper is needed (not on the bus accumulator but on the conversion from non-native permutation/lookup to bus)? @georgwiese --- ast/src/analyzed/display.rs | 9 +- ast/src/analyzed/mod.rs | 11 +- executor/src/witgen/bus_accumulator/mod.rs | 46 ++- .../src/witgen/data_structures/identity.rs | 14 +- pil-analyzer/src/condenser.rs | 28 ++ std/prelude.asm | 7 +- std/protocols/bus.asm | 374 ++++++++++++------ test_data/asm/static_bus_multi.asm | 42 +- 8 files changed, 396 insertions(+), 135 deletions(-) diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index e0f2a3967d..065ad95e7e 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -431,7 +431,7 @@ impl Display for PhantomBusInteractionIdentity { fn fmt(&self, f: &mut Formatter<'_>) -> Result { write!( f, - "Constr::PhantomBusInteraction({}, {}, [{}], {}, [{}], [{}]);", + "Constr::PhantomBusInteraction({}, {}, [{}], {}, [{}], [{}], {});", self.multiplicity, self.bus_id, self.payload.0.iter().map(ToString::to_string).format(", "), @@ -445,6 +445,13 @@ impl Display for PhantomBusInteractionIdentity { .iter() .map(ToString::to_string) .format(", "), + match &self.helper_columns { + Some(helper_columns) => format!( + "Option::Some([{}])", + helper_columns.iter().map(ToString::to_string).format(", ") + ), + None => "Option::None".to_string(), + }, ) } } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 558f22f666..cc3b1bd9a6 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -1037,10 +1037,15 @@ pub struct PhantomBusInteractionIdentity { pub payload: ExpressionList, pub latch: AlgebraicExpression, pub folded_expressions: ExpressionList, - // Note that in PIL, this is a list of expressions, but we'd - // always expect direct column references, so this is unpacked - // when converting from PIL to this struct. + // Note that in PIL, `accumulator_columns` and + // `helper_columns` are lists of expressions, but we'd + // always expect direct column references, because + // they always materialize as witness columns, + // so they are unpacked when converting from PIL + // to this struct, whereas `folded_expressions` + // can be linear and thus optimized away by pilopt. pub accumulator_columns: Vec, + pub helper_columns: Option>, } impl Children> for PhantomBusInteractionIdentity { diff --git a/executor/src/witgen/bus_accumulator/mod.rs b/executor/src/witgen/bus_accumulator/mod.rs index d08157d706..3d1c5181fa 100644 --- a/executor/src/witgen/bus_accumulator/mod.rs +++ b/executor/src/witgen/bus_accumulator/mod.rs @@ -21,6 +21,8 @@ mod extension_field; mod fp2; mod fp4; +pub type InteractionColumns = (Vec>, Vec>, Vec>); + /// Generates the second-stage columns for the bus accumulator. pub fn generate_bus_accumulator_columns<'a, T>( pil: &'a Analyzed, @@ -127,8 +129,9 @@ impl<'a, T: FieldElement, Ext: ExtensionField + Sync> BusAccumulatorGenerator .bus_interactions .par_iter() .flat_map(|bus_interaction| { - let (folded, acc) = self.interaction_columns(bus_interaction); + let (folded, helper, acc) = self.interaction_columns(bus_interaction); collect_folded_columns(bus_interaction, folded) + .chain(collect_helper_columns(bus_interaction, helper)) .chain(collect_acc_columns(bus_interaction, acc)) .collect::>() }) @@ -179,14 +182,20 @@ impl<'a, T: FieldElement, Ext: ExtensionField + Sync> BusAccumulatorGenerator result } + /// Given a bus interaction and existing witness values, + /// calculates and returns a triple tuple of: + /// - the folded columns (one per bus interaction) + /// - one helper column per pair of bus interactions + /// - the accumulator column (shared by all interactions) fn interaction_columns( &self, bus_interaction: &PhantomBusInteractionIdentity, - ) -> (Vec>, Vec>) { + ) -> InteractionColumns { let intermediate_definitions = self.pil.intermediate_definitions(); let size = self.values.height(); let mut folded_list = vec![Ext::zero(); size]; + let mut helper_list = vec![Ext::zero(); size]; let mut acc_list = vec![Ext::zero(); size]; for i in 0..size { @@ -206,28 +215,40 @@ impl<'a, T: FieldElement, Ext: ExtensionField + Sync> BusAccumulatorGenerator .collect::>(); let folded = self.beta - self.fingerprint(&tuple); + let to_add = folded.inverse() * multiplicity; + + let helper = match bus_interaction.helper_columns { + Some(_) => to_add, + None => Ext::zero(), + }; + let new_acc = match multiplicity.is_zero() { true => current_acc, - false => current_acc + folded.inverse() * multiplicity, + false => current_acc + to_add, }; folded_list[i] = folded; + helper_list[i] = helper; acc_list[i] = new_acc; } // Transpose from row-major to column-major & flatten. let mut folded = vec![Vec::with_capacity(size); Ext::size()]; + let mut helper = vec![Vec::with_capacity(size); Ext::size()]; let mut acc = vec![Vec::with_capacity(size); Ext::size()]; for row_index in 0..size { for (col_index, x) in folded_list[row_index].to_vec().into_iter().enumerate() { folded[col_index].push(x); } + for (col_index, x) in helper_list[row_index].to_vec().into_iter().enumerate() { + helper[col_index].push(x); + } for (col_index, x) in acc_list[row_index].to_vec().into_iter().enumerate() { acc[col_index].push(x); } } - (folded, acc) + (folded, helper, acc) } /// Fingerprints a tuples of field elements, using the pre-computed powers of alpha. @@ -279,3 +300,20 @@ fn collect_acc_columns( .zip_eq(acc) .map(|(column_reference, column)| (column_reference.poly_id, column)) } + +fn collect_helper_columns( + bus_interaction: &PhantomBusInteractionIdentity, + helper: Vec>, +) -> impl Iterator)> { + match &bus_interaction.helper_columns { + Some(helper_columns) => { + let pairs: Vec<_> = helper_columns + .iter() + .zip_eq(helper) + .map(|(column_reference, column)| (column_reference.poly_id, column)) + .collect(); + pairs.into_iter() + } + None => Vec::new().into_iter(), + } +} diff --git a/executor/src/witgen/data_structures/identity.rs b/executor/src/witgen/data_structures/identity.rs index db1142c1de..9197ee3619 100644 --- a/executor/src/witgen/data_structures/identity.rs +++ b/executor/src/witgen/data_structures/identity.rs @@ -356,7 +356,7 @@ mod test { r" namespace main(4); col fixed right_latch = [0, 1]*; - col witness right_selector, left_latch, a, b, multiplicities, folded, acc; + col witness right_selector, left_latch, a, b, multiplicities, folded, acc, helper; {constraint} // Selectors should be binary @@ -420,9 +420,9 @@ namespace main(4); // std::protocols::lookup_via_bus::lookup_send and // std::protocols::lookup_via_bus::lookup_receive. let (send, receive) = get_generated_bus_interaction_pair( - // The folded expressions and accumulator is ignored in both the bus send and receive, so we just use the same. - r"Constr::PhantomBusInteraction(main::left_latch, 42, [main::a], main::left_latch, [main::folded], [main::acc]); - Constr::PhantomBusInteraction(-main::multiplicities, 42, [main::b], main::right_latch, [main::folded], [main::acc]);", + // The folded expressions, accumulator, and helper columns are ignored in both the bus send and receive, so we just use the same. + r"Constr::PhantomBusInteraction(main::left_latch, 42, [main::a], main::left_latch, [main::folded], [main::acc], Option::None); + Constr::PhantomBusInteraction(-main::multiplicities, 42, [main::b], main::right_latch, [main::folded], [main::acc], Option::None);", ); assert_eq!( send.selected_payload.to_string(), @@ -478,9 +478,9 @@ namespace main(4); // std::protocols::permutation_via_bus::permutation_send and // std::protocols::permutation_via_bus::permutation_receive. let (send, receive) = get_generated_bus_interaction_pair( - // The folded expressions and accumulator is ignored in both the bus send and receive, so we just use the same. - r"Constr::PhantomBusInteraction(main::left_latch, 42, [main::a], main::left_latch, [main::folded], [main::acc]); - Constr::PhantomBusInteraction(-(main::right_latch * main::right_selector), 42, [main::b], main::right_latch * main::right_selector, [main::folded], [main::acc]);", + // The folded expressions, accumulator, and helper columns are ignored in both the bus send and receive, so we just use the same. + r"Constr::PhantomBusInteraction(main::left_latch, 42, [main::a], main::left_latch, [main::folded], [main::acc], Option::None); + Constr::PhantomBusInteraction(-(main::right_latch * main::right_selector), 42, [main::b], main::right_latch * main::right_selector, [main::folded], [main::acc], Option::None);", ); assert_eq!( send.selected_payload.to_string(), diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 50c8718504..2fd0b1ded6 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -819,6 +819,34 @@ fn to_constraint( .collect(), _ => panic!("Expected array, got {:?}", fields[5]), }, + helper_columns: match fields[6].as_ref() { + Value::Enum(enum_value) => { + assert_eq!(enum_value.enum_decl.name, "std::prelude::Option"); + match enum_value.variant { + "None" => None, + "Some" => { + let fields = enum_value.data.as_ref().unwrap(); + assert_eq!(fields.len(), 1); + match fields[0].as_ref() { + Value::Array(fields) => fields + .iter() + .map(|f| match to_expr(f) { + AlgebraicExpression::Reference(reference) => { + assert!(!reference.next); + reference + } + _ => panic!("Expected reference, got {f:?}"), + }) + .collect::>() + .into(), + _ => panic!("Expected array, got {:?}", fields[0]), + } + } + _ => panic!("Expected Some or None, got {0}", enum_value.variant), + } + } + _ => panic!("Expected Enum, got {:?}", fields[6]), + }, } .into(), _ => panic!("Expected constraint but got {constraint}"), diff --git a/std/prelude.asm b/std/prelude.asm index e4e395ee2d..4b9ee07b2f 100644 --- a/std/prelude.asm +++ b/std/prelude.asm @@ -59,7 +59,12 @@ enum Constr { /// Note that this could refer to witness columns, intermediate columns, or /// in-lined expressions. /// - The list of accumulator columns. - PhantomBusInteraction(expr, expr, expr[], expr, expr[], expr[]) + /// - The list of helper columns that are intermediate values + /// (but materialized witnesses) to help calculate + /// the accumulator columns, so that constraints are always bounded to + /// degree 3. Each set of helper columns is always shared by two bus + /// interactions. + PhantomBusInteraction(expr, expr, expr[], expr, expr[], expr[], Option) } /// This is the result of the "$" operator. It can be used as the left and diff --git a/std/protocols/bus.asm b/std/protocols/bus.asm index bd5b015e6a..567942a953 100644 --- a/std/protocols/bus.asm +++ b/std/protocols/bus.asm @@ -19,12 +19,12 @@ use std::field::known_field; use std::field::KnownField; use std::check::panic; -// Helper function. -// Materialized as a witness column for two reasons: -// - It makes sure the constraint degree is independent of the input payload. -// - We can access folded', even if the payload contains next references. -// Note that if all expressions are degree-1 and there is no next reference, -// this is wasteful, but we can't check that here. +/// Helper function. +/// Materialized as a witness column for two reasons: +/// - It makes sure the constraint degree is independent of the input payload. +/// - We can access folded', even if the payload contains next references. +/// Note that if all expressions are degree-1 and there is no next reference, +/// this is wasteful, but we can't check that here. let materialize_folded: -> bool = || match known_field() { Option::Some(KnownField::Goldilocks) => true, Option::Some(KnownField::BabyBear) => true, @@ -39,17 +39,37 @@ let materialize_folded: -> bool = || match known_field() { _ => panic("Unexpected field!") }; +/// Helper function. +/// Implemented as: folded = (beta - fingerprint(id, payload...)); +let create_folded: expr, expr[], Ext, Ext -> Ext = constr |id, payload, alpha, beta| + if materialize_folded() { + let folded = from_array( + array::new(required_extension_size(), + |i| std::prover::new_witness_col_at_stage("folded", 1)) + ); + constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha))); + folded + } else { + sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha)) + }; + /// Sends the payload (id, payload...) to the bus by adding /// `multiplicity / (beta - fingerprint(id, payload...))` to `acc` /// It is the callers responsibility to properly constrain the multiplicity (e.g. constrain /// it to be boolean) if needed. /// -/// # Arguments: +/// # Arguments are plural for multiple bus interactions. +/// For each bus interaction: /// - id: Interaction Id /// - payload: An array of expressions to be sent to the bus /// - multiplicity: The multiplicity which shows how many times a column will be sent /// - latch: a binary expression which indicates where the multiplicity can be non-zero. -let bus_interaction: expr, expr[], expr, expr -> () = constr |id, payload, multiplicity, latch| { +let bus_multi_interaction: expr[], expr[][], expr[], expr[] -> () = constr |ids, payloads, multiplicities, latches| { + // Check length of inputs + let input_len: int = array::len(ids); + assert(input_len == array::len(payloads), || "inputs ids and payloads have unequal lengths"); + assert(input_len == array::len(multiplicities), || "inputs ids and multiplicities have unequal lengths"); + assert(input_len == array::len(latches), || "inputs ids and latches have unequal lengths"); let extension_field_size = required_extension_size(); @@ -58,22 +78,12 @@ let bus_interaction: expr, expr[], expr, expr -> () = constr |id, payload, multi // Beta is used to update the accumulator. let beta = from_array(array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size))); - // Implemented as: folded = (beta - fingerprint(id, payload...)); - let folded = if materialize_folded() { - let folded = from_array( - array::new(extension_field_size, - |i| std::prover::new_witness_col_at_stage("folded", 1)) - ); - constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha))); - folded - } else { - sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha)) - }; - - let folded_next = next_ext(folded); + // Create folded columns. + let folded_arr = array::new(input_len, |i| create_folded(ids[i], payloads[i], alpha, beta)); // Ext[] + let folded_next_arr = array::map(folded_arr, |folded| next_ext(folded)); // Ext[] - let m_ext = from_base(multiplicity); - let m_ext_next = next_ext(m_ext); + let m_ext_arr = array::map(multiplicities, |multiplicity| from_base(multiplicity)); // Ext[] + let m_ext_next_arr = array::map(m_ext_arr, |m_ext| next_ext(m_ext)); // Ext[] let acc = array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1)); let acc_ext = from_array(acc); @@ -82,25 +92,224 @@ let bus_interaction: expr, expr[], expr, expr -> () = constr |id, payload, multi let is_first: col = std::well_known::is_first; let is_first_next = from_base(is_first'); + // Create helper columns to bound degree to 3 for arbitrary number of bus interactions. + // Each helper processes two bus interactions: + // helper_i = multiplicity_{2*i} / folded_{2*i} + multiplicity_{2*i+1} / folded_{2*i+1} + // Or equivalently when expanded: + // folded_{2*i} * folded_{2*i+1}' * helper_i - folded_{2*i+1} * multiplicity_{2*i} - folded_{2*i} * multiplicity_{2*i+1} = 0 + let helper_arr: expr[][] = array::new( + input_len / 2, + |helper| + array::new( + extension_field_size, + |column| std::prover::new_witness_col_at_stage("helper", 1) + ) + ); + let helper_ext_arr = array::map( // Ext[] (somehow type annotating this will cause a symbol not found error in analyzer) + helper_arr, + |helper| from_array(helper) + ); + let helper_ext_next_arr = array::map( + helper_ext_arr, + |helper_ext| next_ext(helper_ext) + ); + // The expression to constrain. + let helper_expr_arr = array::new( // Ext[] + input_len / 2, + |i| sub_ext( + sub_ext( + mul_ext( + mul_ext(folded_arr[2 * i], folded_arr[2 * i + 1]), + helper_ext_arr[i] + ), + mul_ext(folded_arr[2 * i + 1], m_ext_arr[2 * i]) + ), + mul_ext(folded_arr[2 * i], m_ext_arr[2 * i + 1]) + ) + ); + // Return a flattened array of constraints. (Must use `array::fold` or the compiler won't allow nested Constr[][].) + array::fold(helper_expr_arr, [], |init, helper_expr| constrain_eq_ext(helper_expr, from_base(0))); + // Update rule: - // acc' = acc * (1 - is_first') + multiplicity' / folded' - // or equivalently: - // folded' * (acc' - acc * (1 - is_first')) - multiplicity' = 0 - let update_expr = sub_ext( - mul_ext(folded_next, sub_ext(next_acc, mul_ext(acc_ext, sub_ext(from_base(1), is_first_next)))), m_ext_next + // acc' = acc * (1 - is_first') + helper_0' + helper_1' + ... + // Add up all helper columns. + // Or equivalently: + // acc * (1 - is_first') + helper_0' + helper_1' + ... - acc' = 0 + let update_expr = + sub_ext( + add_ext( + mul_ext( + acc_ext, + sub_ext(from_base(1), is_first_next) + ), + // Sum of all helper columns. + array::fold(helper_ext_next_arr, from_base(0), |sum, helper_ext_next| add_ext(sum, helper_ext_next)) + ), + next_acc + ); + + // In cases where there are odd number of bus interactions, the last bus interaction doesn't need helper column. + // Instead, we have `update_expr` + multiplicity_last' / folded_last' = 0 + // Or equivalently: + // `update_expr` * folded_last' + multiplicity_last' = 0 + let update_expr_final = if input_len % 2 == 1 { + // Odd number of bus interactions + add_ext( + mul_ext( + update_expr, + folded_next_arr[input_len - 1] + ), + m_ext_next_arr[input_len - 1] + ) + } else { + // Even number of bus interactions + update_expr + }; + + // Constrain the accumulator update identity + constrain_eq_ext(update_expr_final, from_base(0)); + + // Add array of phantom bus interactions + array::new( + input_len, + |i| if input_len % 2 == 1 && i == input_len - 1 { + Constr::PhantomBusInteraction( + multiplicities[i], + ids[i], + payloads[i], + latches[i], + unpack_ext_array(folded_arr[i]), + acc, + Option::None + ) + } else { + Constr::PhantomBusInteraction( + multiplicities[i], + ids[i], + payloads[i], + latches[i], + unpack_ext_array(folded_arr[i]), + acc, + Option::Some(helper_arr[i / 2]) + ) + } ); +}; + +/// Compute acc' = acc * (1 - is_first') + multiplicity' / fingerprint(id, payload...), +/// using extension field arithmetic. +/// This is intended to be used as a hint in the extension field case; for the base case +/// automatic witgen is smart enough to figure out the value of the accumulator. +let compute_next_z: expr, expr, expr[], expr, Ext, Ext, Ext -> fe[] = query |is_first, id, payload, multiplicity, acc, alpha, beta| { + + let m_next = eval(multiplicity'); + let m_ext_next = from_base(m_next); + + let is_first_next = eval(is_first'); + let current_acc = if is_first_next == 1 {from_base(0)} else {eval_ext(acc)}; - constrain_eq_ext(update_expr, from_base(0)); + // acc' = current_acc + multiplicity' / folded' + let res = if m_next == 0 { + current_acc + } else { + // Implemented as: folded = (beta - fingerprint(id, payload...)); + // `multiplicity / (beta - fingerprint(id, payload...))` to `acc` + let folded_next = sub_ext(eval_ext(beta), fingerprint_with_id(eval(id'), array::eval(array::next(payload)), alpha)); + add_ext( + current_acc, + mul_ext(m_ext_next, inv_ext(folded_next)) + ) + }; - // Add phantom bus interaction - Constr::PhantomBusInteraction(multiplicity, id, payload, latch, unpack_ext_array(folded), acc); + unpack_ext_array(res) +}; + +/// Helper function. +/// Transpose user interface friendly bus send input format `(expr, expr[], expr)[]` +/// to constraint building friendly bus send input format `expr[], expr[][], expr[]`, i.e. id, payload, multiplicity. +/// This is because Rust-style tuple indexing, e.g. tuple.0, isn't supported yet. +let transpose_bus_send_inputs: (expr, expr[], expr)[] -> (expr[], expr[][], expr[]) = |bus_inputs| { + let ids: expr[] = array::map(bus_inputs, + |bus_input| { + let (id, _, _) = bus_input; + id + } + ); + let payloads: expr[][] = array::map(bus_inputs, + |bus_input| { + let (_, payload, _) = bus_input; + payload + } + ); + let multiplicities: expr[] = array::map(bus_inputs, + |bus_input| { + let (_, _, multiplicity) = bus_input; + multiplicity + } + ); + (ids, payloads, multiplicities) +}; + +/// Convenience function for batching multiple bus sends. +/// Transposes user inputs and then calls the key logic for batch building bus interactions. +let bus_multi_send: (expr, expr[], expr)[] -> () = constr |bus_inputs| { + let (ids, payloads, multiplicities) = transpose_bus_send_inputs(bus_inputs); + // For bus sends, the multiplicity always equals the latch + bus_multi_interaction(ids, payloads, multiplicities, multiplicities); +}; + +/// Helper function. +/// Transpose user interface friendly bus send input format `(expr, expr[], expr, expr)[]` +/// to constraint building friendly bus send input format `expr[], expr[][], expr[], expr[]`, i.e. id, payload, multiplicity, latch. +/// This is because Rust-style tuple indexing, e.g. tuple.0, isn't supported yet. +let transpose_bus_receive_inputs: (expr, expr[], expr, expr)[] -> (expr[], expr[][], expr[], expr[]) = |bus_inputs| { + let ids: expr[] = array::map(bus_inputs, + |bus_input| { + let (id, _, _, _) = bus_input; + id + } + ); + let payloads: expr[][] = array::map(bus_inputs, + |bus_input| { + let (_, payload, _, _) = bus_input; + payload + } + ); + let negated_multiplicities: expr[] = array::map(bus_inputs, + |bus_input| { + let (_, _, multiplicity, _) = bus_input; + -multiplicity + } + ); + let latches: expr[] = array::map(bus_inputs, + |bus_input| { + let (_, _, _, latch) = bus_input; + latch + } + ); + (ids, payloads, negated_multiplicities, latches) +}; + +/// Convenience function for batching multiple bus receives. +/// Transposes user inputs and then calls the key logic for batch building bus interactions. +/// In practice, can also batch bus send and bus receive, but requires knowledge of this function and careful configuration of input parameters. +/// E.g. sending negative multiplicity and multiplicity for "multiplicity" and "latch" parameters for bus sends. +let bus_multi_receive: (expr, expr[], expr, expr)[] -> () = constr |bus_inputs| { + let (ids, payloads, negated_multiplicities, latches) = transpose_bus_receive_inputs(bus_inputs); + bus_multi_interaction(ids, payloads, negated_multiplicities, latches); }; -/// Multi version of `bus_interaction`. -/// Batches two bus interactions. -/// Requires a prove system constraint degree bound of 4 or more (so won't work with our setup of Plonky3). -/// In practice, saves `acc`, `is_first`, `alpha`, and `beta` columns as well as rotated columns thereof. -let bus_multi_interaction: expr, expr[], expr, expr, expr, expr[], expr, expr -> () = constr |id_0, payload_0, multiplicity_0, latch_0, id_1, payload_1, multiplicity_1, latch_1| { +/// Sends the payload (id, payload...) to the bus by adding +/// `multiplicity / (beta - fingerprint(id, payload...))` to `acc` +/// It is the callers responsibility to properly constrain the multiplicity (e.g. constrain +/// it to be boolean) if needed. +/// +/// # Arguments: +/// - id: Interaction Id +/// - payload: An array of expressions to be sent to the bus +/// - multiplicity: The multiplicity which shows how many times a column will be sent +/// - latch: a binary expression which indicates where the multiplicity can be non-zero. +let bus_interaction: expr, expr[], expr, expr -> () = constr |id, payload, multiplicity, latch| { let extension_field_size = required_extension_size(); @@ -110,35 +319,21 @@ let bus_multi_interaction: expr, expr[], expr, expr, expr, expr[], expr, expr -> let beta = from_array(array::new(extension_field_size, |i| challenge(0, i + 1 + extension_field_size))); // Implemented as: folded = (beta - fingerprint(id, payload...)); - let folded_0 = if materialize_folded() { - let folded_0 = from_array( - array::new(extension_field_size, - |i| std::prover::new_witness_col_at_stage("folded_0", 1)) - ); - constrain_eq_ext(folded_0, sub_ext(beta, fingerprint_with_id_inter(id_0, payload_0, alpha))); - folded_0 - } else { - sub_ext(beta, fingerprint_with_id_inter(id_0, payload_0, alpha)) - }; - let folded_1 = if materialize_folded() { - let folded_1 = from_array( + let folded = if materialize_folded() { + let folded = from_array( array::new(extension_field_size, - |i| std::prover::new_witness_col_at_stage("folded_1", 1)) + |i| std::prover::new_witness_col_at_stage("folded", 1)) ); - constrain_eq_ext(folded_1, sub_ext(beta, fingerprint_with_id_inter(id_1, payload_1, alpha))); - folded_1 + constrain_eq_ext(folded, sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha))); + folded } else { - sub_ext(beta, fingerprint_with_id_inter(id_1, payload_1, alpha)) + sub_ext(beta, fingerprint_with_id_inter(id, payload, alpha)) }; - let folded_next_0 = next_ext(folded_0); - let folded_next_1 = next_ext(folded_1); - - let m_ext_0 = from_base(multiplicity_0); - let m_ext_1 = from_base(multiplicity_1); + let folded_next = next_ext(folded); - let m_ext_next_0 = next_ext(m_ext_0); - let m_ext_next_1 = next_ext(m_ext_1); + let m_ext = from_base(multiplicity); + let m_ext_next = next_ext(m_ext); let acc = array::new(extension_field_size, |i| std::prover::new_witness_col_at_stage("acc", 1)); let acc_ext = from_array(acc); @@ -148,75 +343,26 @@ let bus_multi_interaction: expr, expr[], expr, expr, expr, expr[], expr, expr -> let is_first_next = from_base(is_first'); // Update rule: - // acc' = acc * (1 - is_first') + multiplicity_0' / folded_0' + multiplicity_1' / folded_1' + // acc' = acc * (1 - is_first') + multiplicity' / folded' // or equivalently: - // folded_0' * folded_1' * (acc' - acc * (1 - is_first')) - multiplicity_0' * folded_1' - multiplicity_1' * folded_0' = 0 + // folded' * (acc' - acc * (1 - is_first')) - multiplicity' = 0 let update_expr = sub_ext( - sub_ext( - mul_ext( - mul_ext(folded_next_0, folded_next_1), - sub_ext(next_acc, mul_ext(acc_ext, sub_ext(from_base(1), is_first_next))) - ), - mul_ext(m_ext_next_0, folded_next_1) - ), - mul_ext(m_ext_next_1, folded_next_0) + mul_ext(folded_next, sub_ext(next_acc, mul_ext(acc_ext, sub_ext(from_base(1), is_first_next)))), m_ext_next ); constrain_eq_ext(update_expr, from_base(0)); // Add phantom bus interaction - Constr::PhantomBusInteraction(multiplicity_0, id_0, payload_0, latch_0, unpack_ext_array(folded_0), acc); - Constr::PhantomBusInteraction(multiplicity_1, id_1, payload_1, latch_1, unpack_ext_array(folded_1), acc); -}; - -/// Compute acc' = acc * (1 - is_first') + multiplicity' / fingerprint(id, payload...), -/// using extension field arithmetic. -/// This is intended to be used as a hint in the extension field case; for the base case -/// automatic witgen is smart enough to figure out the value of the accumulator. -let compute_next_z: expr, expr, expr[], expr, Ext, Ext, Ext -> fe[] = query |is_first, id, payload, multiplicity, acc, alpha, beta| { - - let m_next = eval(multiplicity'); - let m_ext_next = from_base(m_next); - - let is_first_next = eval(is_first'); - let current_acc = if is_first_next == 1 {from_base(0)} else {eval_ext(acc)}; - - // acc' = current_acc + multiplicity' / folded' - let res = if m_next == 0 { - current_acc - } else { - // Implemented as: folded = (beta - fingerprint(id, payload...)); - // `multiplicity / (beta - fingerprint(id, payload...))` to `acc` - let folded_next = sub_ext(eval_ext(beta), fingerprint_with_id(eval(id'), array::eval(array::next(payload)), alpha)); - add_ext( - current_acc, - mul_ext(m_ext_next, inv_ext(folded_next)) - ) - }; - - unpack_ext_array(res) + Constr::PhantomBusInteraction(multiplicity, id, payload, latch, unpack_ext_array(folded), acc, Option::None); }; -/// Convenience function for bus interaction to send columns +/// Convenience function for single bus interaction to send columns let bus_send: expr, expr[], expr -> () = constr |id, payload, multiplicity| { // For bus sends, the multiplicity always equals the latch bus_interaction(id, payload, multiplicity, multiplicity); }; -/// Convenience function for bus interaction to receive columns +/// Convenience function for single bus interaction to receive columns let bus_receive: expr, expr[], expr, expr -> () = constr |id, payload, multiplicity, latch| { bus_interaction(id, payload, -multiplicity, latch); }; - -/// Convenience function for batching two bus sends. -let bus_multi_send: expr, expr[], expr, expr, expr[], expr -> () = constr |id_0, payload_0, multiplicity_0, id_1, payload_1, multiplicity_1| { - // For bus sends, the multiplicity always equals the latch - bus_multi_interaction(id_0, payload_0, multiplicity_0, multiplicity_0, id_1, payload_1, multiplicity_1, multiplicity_1); -}; - -/// Convenience function for batching two bus receives. -/// In practice, can also batch one bus send and one bus receive, but requires knowledge of this function and careful configuration of input parameters. -/// E.g. sending negative multiplicity and multiplicity for "multiplicity" and "latch" parameters for bus sends. -let bus_multi_receive: expr, expr[], expr, expr, expr, expr[], expr, expr -> () = constr |id_0, payload_0, multiplicity_0, latch_0, id_1, payload_1, multiplicity_1, latch_1| { - bus_multi_interaction(id_0, payload_0, -multiplicity_0, latch_0, id_1, payload_1, -multiplicity_1, latch_1); -}; diff --git a/test_data/asm/static_bus_multi.asm b/test_data/asm/static_bus_multi.asm index c08b3eca72..aa07c1c86e 100644 --- a/test_data/asm/static_bus_multi.asm +++ b/test_data/asm/static_bus_multi.asm @@ -3,6 +3,9 @@ use std::protocols::bus::bus_multi_send; let ADD_BUS_ID = 123; let MUL_BUS_ID = 456; +let SUB_BUS_ID = 789; +let DOUBLE_BUS_ID = 234; +let TRIPLE_BUS_ID = 321; machine Main with degree: 8, @@ -23,14 +26,38 @@ machine Main with std::utils::force_bool(mul_sel); mul_c = mul_a * mul_b; + // Sub block machine + col witness sub_a, sub_b, sub_c, sub_sel; + std::utils::force_bool(sub_sel); + sub_c = sub_a - sub_b; + + // Double block machine + col witness double_a, double_b, double_c, double_sel; + std::utils::force_bool(double_sel); + double_c = 2 * double_a + 2 * double_b; + + // Triple block machine + col witness triple_a, triple_b, triple_c, triple_sel; + std::utils::force_bool(triple_sel); + triple_c = 3 * triple_a + 3 * triple_b; + // Multi bus receive bus_multi_receive( - ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel, - MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel + [ + (ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel), + (MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel), + (SUB_BUS_ID, [sub_a, sub_b, sub_c], sub_sel, sub_sel), + (DOUBLE_BUS_ID, [double_a, double_b, double_c], double_sel, double_sel), + (TRIPLE_BUS_ID, [triple_a, triple_b, triple_c], triple_sel, triple_sel) + ] ); // Main machine - col fixed is_mul = [0, 1]*; + col fixed is_add = [1, 0, 0, 0, 0]*; + col fixed is_mul = [0, 1, 0, 0, 0]*; + col fixed is_sub = [0, 0, 1, 0, 0]*; + col fixed is_double = [0, 0, 0, 1, 0]*; + col fixed is_triple = [0, 0, 0, 0, 1]*; col fixed x(i) {i * 42}; col fixed y(i) {i + 12345}; col witness z; @@ -39,7 +66,12 @@ machine Main with // a bus send for each receiver, even though at most one send will be // active in each row. bus_multi_send( - MUL_BUS_ID, [x, y, z], is_mul, - ADD_BUS_ID, [x, y, z], 1 - is_mul + [ + (MUL_BUS_ID, [x, y, z], is_mul), + (ADD_BUS_ID, [x, y, z], is_add), + (DOUBLE_BUS_ID, [x, y, z], is_double), + (SUB_BUS_ID, [x, y, z], is_sub), + (TRIPLE_BUS_ID, [x, y, z], is_triple) + ] ); }