From c100ca1079c0b41a3b000ecfded6aabb21d04ade Mon Sep 17 00:00:00 2001 From: Christian Lewe <ch_lewe@agdsn.me> Date: Thu, 29 Aug 2024 11:28:58 +0200 Subject: [PATCH] refactor: Type-check Simplicity values Values have two bit encodings: 1) Compact: witness encoding, IMR computation 2) Padded: values on the Bit Machine The code so far used the compact encoding exclusively, which is incorrect for sum values on the Bit Machine. As a fix, this commit introduces the padded encoding. However, a value can only be encoded with padding if its type is known. To this end, this commit refactors the Value struct so it knows its Simplicity type. Adding types to all values instead of "upgrading" a typeless value to a typed value has several benefits: 1) Bit words, products and options get their type for free when they are constructed. The caller doesn't have to supply additional type info. This covers almost all values that we are using in the code. 2) Values cannot be decoded without type info. Type checking happens implicitly during decoding. The decoded value gets its type for free. 3) Values are used as word constants and as witness data. In both cases, we want to check if the supplied value is of the correct type. There is (almost?) no use case for untyped values. The new API exposes two Boolean iterators (iter_compact / iter_padded) instead of do_each_bit. The iterators are more flexible and can be used in conjunction with BitCollector. Most changes in this commit happen inside value.rs. The rest is renamings of Arc<Value> to Value, etc. --- jets-bench/benches/elements/main.rs | 17 +- jets-bench/src/data_structures.rs | 63 +++--- jets-bench/src/input.rs | 38 ++-- src/analysis.rs | 2 +- src/bit_encoding/bititer.rs | 24 +- src/bit_encoding/decode.rs | 10 +- src/bit_encoding/encode.rs | 27 +-- src/bit_encoding/mod.rs | 2 +- src/bit_machine/mod.rs | 18 +- src/human_encoding/mod.rs | 6 +- src/human_encoding/named_node.rs | 8 +- src/human_encoding/parse/mod.rs | 4 +- src/human_encoding/serialize.rs | 15 +- src/merkle/cmr.rs | 20 +- src/merkle/mod.rs | 3 +- src/node/construct.rs | 4 +- src/node/convert.rs | 14 +- src/node/inner.rs | 6 +- src/node/mod.rs | 37 ++-- src/node/redeem.rs | 21 +- src/node/witness.rs | 35 ++- src/policy/satisfy.rs | 32 ++- src/policy/serialize.rs | 2 +- src/types/arrow.rs | 6 +- src/value.rs | 331 ++++++++++++++++------------ 25 files changed, 396 insertions(+), 349 deletions(-) diff --git a/jets-bench/benches/elements/main.rs b/jets-bench/benches/elements/main.rs index c3ec188e..ac0456f0 100644 --- a/jets-bench/benches/elements/main.rs +++ b/jets-bench/benches/elements/main.rs @@ -7,6 +7,7 @@ use simplicity::elements; use simplicity::jet::elements::ElementsEnv; use simplicity::jet::{Elements, Jet}; use simplicity::types; +use simplicity::types::Final; use simplicity::Value; use simplicity_bench::input::{ self, EqProduct, GenericProduct, InputSample, PrefixBit, Sha256Ctx, UniformBits, @@ -751,36 +752,36 @@ fn bench(c: &mut Criterion) { } // Input to outpoint hash jet - fn outpoint_hash() -> Arc<Value> { + fn outpoint_hash() -> Value { let ctx8 = SimplicityCtx8::with_len(511).value(); let genesis_pegin = genesis_pegin(); let outpoint = elements::OutPoint::sample().value(); Value::product(ctx8, Value::product(genesis_pegin, outpoint)) } - fn asset_amount_hash() -> Arc<Value> { + fn asset_amount_hash() -> Value { let ctx8 = SimplicityCtx8::with_len(511).value(); let asset = confidential::Asset::sample().value(); let amount = confidential::Value::sample().value(); Value::product(ctx8, 
Value::product(asset, amount)) } - fn nonce_hash() -> Arc<Value> { + fn nonce_hash() -> Value { let ctx8 = SimplicityCtx8::with_len(511).value(); let nonce = confidential::Nonce::sample().value(); Value::product(ctx8, nonce) } - fn annex_hash() -> Arc<Value> { + fn annex_hash() -> Value { let ctx8 = SimplicityCtx8::with_len(511).value(); let annex = if rand::random() { - Value::right(Value::u256(rand::random::<[u8; 32]>())) + Value::some(Value::u256(rand::random::<[u8; 32]>())) } else { - Value::left(Value::unit()) + Value::none(Final::u256()) }; Value::product(ctx8, annex) } - let arr: [(Elements, Arc<dyn Fn() -> Arc<Value>>); 4] = [ + let arr: [(Elements, Arc<dyn Fn() -> Value>); 4] = [ (Elements::OutpointHash, Arc::new(&outpoint_hash)), (Elements::AssetAmountHash, Arc::new(&asset_amount_hash)), (Elements::NonceHash, Arc::new(nonce_hash)), @@ -814,7 +815,7 @@ fn bench(c: &mut Criterion) { } // Operations that use tx input or output index. - fn index_value(bound: u32) -> Arc<Value> { + fn index_value(bound: u32) -> Value { let v = rand::random::<u32>() % bound; Value::u32(v) } diff --git a/jets-bench/src/data_structures.rs b/jets-bench/src/data_structures.rs index 996c08f9..edafeba8 100644 --- a/jets-bench/src/data_structures.rs +++ b/jets-bench/src/data_structures.rs @@ -8,10 +8,9 @@ use simplicity::{ bitcoin, elements, hashes::Hash, hex::FromHex, - types::{self, Type}, + types::Final, BitIter, Error, Value, }; -use std::sync::Arc; /// Engine to compute SHA256 hash function. /// We can't use hashes::sha256::HashEngine because it does not accept @@ -55,21 +54,19 @@ impl SimplicityCtx8 { /// # Panics: /// /// Panics if the length of the slice is >= 2^(n + 1) bytes -pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Arc<Value>, Error> { +pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Value, Error> { // Simplicity consensus rule for n < 16 while reading buffers. assert!(n < 16); assert!(v.len() < (1 << (n + 1))); let mut iter = BitIter::new(v.iter().copied()); - let ctx = types::Context::new(); - let types = Type::powers_of_two(&ctx, n); // size n + 1 let mut res = None; while n > 0 { + let ty = Final::two_two_n(n); let v = if v.len() >= (1 << (n + 1)) { - let ty = &types[n]; - let val = iter.read_value(&ty.final_data().unwrap())?; - Value::right(val) + let val = iter.read_value(&ty)?; + Value::some(val) } else { - Value::left(Value::unit()) + Value::none(ty) }; res = match res { Some(prod) => Some(Value::product(prod, v)), @@ -155,11 +152,11 @@ pub struct SimplicityPoint(pub bitcoin::secp256k1::PublicKey); /// Trait defining how to encode a data structure into a Simplicity value /// This is then used to write these vales into the bit machine. 
pub trait SimplicityEncode { - fn value(&self) -> Arc<Value>; + fn value(&self) -> Value; } impl SimplicityEncode for SimplicityCtx8 { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { let buf_len = self.length % 512; let buf = var_len_buf_from_slice(&self.buffer[..buf_len], 8).unwrap(); let len = Value::u64(self.length as u64); @@ -174,7 +171,7 @@ impl SimplicityEncode for SimplicityCtx8 { } impl SimplicityEncode for elements::OutPoint { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { let txid = Value::u256(self.txid.to_byte_array()); let vout = Value::u32(self.vout); Value::product(txid, vout) @@ -182,10 +179,11 @@ impl SimplicityEncode for elements::OutPoint { } impl SimplicityEncode for elements::confidential::Asset { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { match self { elements::confidential::Asset::Explicit(a) => { - Value::right(Value::u256(a.into_inner().to_byte_array())) + let left = Final::product(Final::u1(), Final::u256()); + Value::right(left, Value::u256(a.into_inner().to_byte_array())) } elements::confidential::Asset::Confidential(gen) => { let ser = gen.serialize(); @@ -193,7 +191,7 @@ impl SimplicityEncode for elements::confidential::Asset { let x_bytes = (&ser[1..33]).try_into().unwrap(); let x_pt = Value::u256(x_bytes); let y_pt = Value::u1(odd_gen as u8); - Value::left(Value::product(y_pt, x_pt)) + Value::left(Value::product(y_pt, x_pt), Final::u256()) } elements::confidential::Asset::Null => panic!("Tried to encode Null asset"), } @@ -201,15 +199,18 @@ impl SimplicityEncode for elements::confidential::Asset { } impl SimplicityEncode for elements::confidential::Value { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { match self { - elements::confidential::Value::Explicit(v) => Value::right(Value::u64(*v)), + elements::confidential::Value::Explicit(v) => { + let left = Final::product(Final::u1(), Final::u256()); + Value::right(left, Value::u64(*v)) + }, elements::confidential::Value::Confidential(v) => { let ser = v.serialize(); let x_bytes = (&ser[1..33]).try_into().unwrap(); let x_pt = Value::u256(x_bytes); let y_pt = Value::u1((ser[0] & 1 == 1) as u8); - Value::left(Value::product(y_pt, x_pt)) + Value::left(Value::product(y_pt, x_pt), Final::u64()) } elements::confidential::Value::Null => panic!("Tried to encode Null value"), } @@ -217,31 +218,33 @@ impl SimplicityEncode for elements::confidential::Value { } impl SimplicityEncode for elements::confidential::Nonce { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { + let ty_l = Final::product(Final::u1(), Final::u256()); + let ty_r = Final::u256(); match self { elements::confidential::Nonce::Explicit(n) => { - Value::right(Value::right(Value::u256(*n))) + Value::some(Value::right(ty_l, Value::u256(*n))) } elements::confidential::Nonce::Confidential(n) => { let ser = n.serialize(); let x_bytes = (&ser[1..33]).try_into().unwrap(); let x_pt = Value::u256(x_bytes); let y_pt = Value::u1((ser[0] & 1 == 1) as u8); - Value::right(Value::left(Value::product(y_pt, x_pt))) + Value::some(Value::left(Value::product(y_pt, x_pt), ty_r)) } - elements::confidential::Nonce::Null => Value::left(Value::unit()), + elements::confidential::Nonce::Null => Value::none(Final::sum(ty_l, ty_r)), } } } impl SimplicityEncode for SimplicityFe { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { Value::u256(*self.as_inner()) } } impl SimplicityEncode for SimplicityGe { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { let ser = match &self { 
SimplicityGe::ValidPoint(p) => p.serialize_uncompressed(), SimplicityGe::InvalidPoint(x, y) => { @@ -261,7 +264,7 @@ impl SimplicityEncode for SimplicityGe { } impl SimplicityEncode for SimplicityGej { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { let ge = self.ge.value(); let z = self.z.value(); Value::product(ge, z) @@ -269,13 +272,13 @@ impl SimplicityEncode for SimplicityGej { } impl SimplicityEncode for SimplicityScalar { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { Value::u256(self.0) } } impl SimplicityEncode for SimplicityPoint { - fn value(&self) -> Arc<Value> { + fn value(&self) -> Value { let ser = self.0.serialize(); // compressed let x_bytes = (&ser[1..33]).try_into().unwrap(); let y_pt = Value::u1((ser[0] & 1 == 1) as u8); @@ -407,11 +410,11 @@ impl BenchSample for SimplicityPoint { } // Sample genesis pegin with 50% probability -pub fn genesis_pegin() -> Arc<Value> { +pub fn genesis_pegin() -> Value { if rand::random() { - Value::left(Value::unit()) + Value::none(Final::two_two_n(8)) } else { let genesis_hash = rand::random::<[u8; 32]>(); - Value::right(Value::u256(genesis_hash)) + Value::some(Value::u256(genesis_hash)) } } diff --git a/jets-bench/src/input.rs b/jets-bench/src/input.rs index 2a04a325..b5bd2fb6 100644 --- a/jets-bench/src/input.rs +++ b/jets-bench/src/input.rs @@ -10,11 +10,11 @@ use simplicity::jet::Elements; use simplicity::types::{self, CompleteBound}; use simplicity::Value; -pub fn random_value(ty: &types::Final, rng: &mut ThreadRng) -> Arc<Value> { +pub fn random_value(ty: &types::Final, rng: &mut ThreadRng) -> Value { enum StackItem<'a> { Type(&'a types::Final), - LeftSum, - RightSum, + LeftSum(Arc<types::Final>), + RightSum(Arc<types::Final>), Product, } @@ -27,10 +27,10 @@ pub fn random_value(ty: &types::Final, rng: &mut ThreadRng) -> Arc<Value> { CompleteBound::Unit => value_stack.push(Value::unit()), CompleteBound::Sum(left, right) => { if rng.gen() { - call_stack.push(StackItem::LeftSum); + call_stack.push(StackItem::LeftSum(Arc::clone(right))); call_stack.push(StackItem::Type(left)); } else { - call_stack.push(StackItem::RightSum); + call_stack.push(StackItem::RightSum(Arc::clone(left))); call_stack.push(StackItem::Type(right)); } } @@ -40,13 +40,13 @@ pub fn random_value(ty: &types::Final, rng: &mut ThreadRng) -> Arc<Value> { call_stack.push(StackItem::Type(left)); } }, - StackItem::LeftSum => { + StackItem::LeftSum(right) => { let left = value_stack.pop().unwrap(); - value_stack.push(Value::left(left)); + value_stack.push(Value::left(left, right)); } - StackItem::RightSum => { + StackItem::RightSum(left) => { let right = value_stack.pop().unwrap(); - value_stack.push(Value::right(right)); + value_stack.push(Value::right(left, right)); } StackItem::Product => { let right = value_stack.pop().unwrap(); @@ -411,10 +411,10 @@ pub enum InputSampling { Random, /// A given, fixed bit string (whose length is multiple of 8) /// Worst-case inputs - Fixed(Arc<Value>), + Fixed(Value), /// Custom sampling method, read first src type bits from input /// Useful for cases where we want to sample inputs according to some distributions - Custom(Arc<dyn Fn() -> Arc<Value>>), + Custom(Arc<dyn Fn() -> Value>), } impl InputSampling { @@ -424,19 +424,25 @@ impl InputSampling { src_ty: &types::Final, rng: &mut ThreadRng, ) { - let write_bit = |bit: bool| unsafe { c_writeBit(src_frame, bit) }; + let mut write_bit = |bit: bool| unsafe { c_writeBit(src_frame, bit) }; match self { InputSampling::Random => { let value = random_value(src_ty, 
rng); - value.do_each_bit(write_bit); + for bit in value.iter_padded() { + write_bit(bit); + } } - InputSampling::Fixed(v) => { - v.do_each_bit(write_bit); + InputSampling::Fixed(value) => { + for bit in value.iter_padded() { + write_bit(bit); + } } InputSampling::Custom(gen_bytes) => { let value = gen_bytes(); - value.do_each_bit(write_bit); + for bit in value.iter_padded() { + write_bit(bit); + } } } } diff --git a/src/analysis.rs b/src/analysis.rs index 25b9ba6d..10b4ebb0 100644 --- a/src/analysis.rs +++ b/src/analysis.rs @@ -317,7 +317,7 @@ impl NodeBounds { NodeBounds { extra_cells: 0, extra_frames: 0, - cost: Cost::OVERHEAD + Cost::of_type(value.len()), + cost: Cost::OVERHEAD + Cost::of_type(value.padded_len()), } } diff --git a/src/bit_encoding/bititer.rs b/src/bit_encoding/bititer.rs index f408fad2..c7e8d811 100644 --- a/src/bit_encoding/bititer.rs +++ b/src/bit_encoding/bititer.rs @@ -9,6 +9,7 @@ //! `Iterator<Item=bool>`. //! +use crate::types::Final; use crate::{decode, types}; use crate::{Cmr, FailEntropy, Value}; use std::sync::Arc; @@ -221,11 +222,11 @@ impl<I: Iterator<Item = u8>> BitIter<I> { } /// Decode a value from bits, based on the given type. - pub fn read_value(&mut self, ty: &types::Final) -> Result<Arc<Value>, EarlyEndOfStreamError> { + pub fn read_value(&mut self, ty: &Final) -> Result<Value, EarlyEndOfStreamError> { enum State<'a> { - ProcessType(&'a types::Final), - DoSumL, - DoSumR, + ProcessType(&'a Final), + DoSumL(Arc<Final>), + DoSumR(Arc<Final>), DoProduct, } @@ -237,10 +238,10 @@ impl<I: Iterator<Item = u8>> BitIter<I> { types::CompleteBound::Unit => result_stack.push(Value::unit()), types::CompleteBound::Sum(ref l, ref r) => { if self.read_bit()? { - stack.push(State::DoSumR); + stack.push(State::DoSumR(Arc::clone(l))); stack.push(State::ProcessType(r)); } else { - stack.push(State::DoSumL); + stack.push(State::DoSumL(Arc::clone(r))); stack.push(State::ProcessType(l)); } } @@ -250,13 +251,13 @@ impl<I: Iterator<Item = u8>> BitIter<I> { stack.push(State::ProcessType(l)); } }, - State::DoSumL => { + State::DoSumL(r) => { let val = result_stack.pop().unwrap(); - result_stack.push(Value::left(val)); + result_stack.push(Value::left(val, r)); } - State::DoSumR => { + State::DoSumR(l) => { let val = result_stack.pop().unwrap(); - result_stack.push(Value::right(val)); + result_stack.push(Value::right(l, val)); } State::DoProduct => { let val_r = result_stack.pop().unwrap(); @@ -265,7 +266,7 @@ impl<I: Iterator<Item = u8>> BitIter<I> { } } } - assert_eq!(result_stack.len(), 1); + debug_assert_eq!(result_stack.len(), 1); Ok(result_stack.pop().unwrap()) } @@ -305,7 +306,6 @@ impl<I: Iterator<Item = u8>> BitIter<I> { } } -#[allow(dead_code)] /// Functionality for Boolean iterators to collect their bits or bytes. pub trait BitCollector: Sized { /// Collect the bits of the iterator into a byte vector and a bit length. 
diff --git a/src/bit_encoding/decode.rs b/src/bit_encoding/decode.rs index d6042daf..e0fdfa1b 100644 --- a/src/bit_encoding/decode.rs +++ b/src/bit_encoding/decode.rs @@ -129,7 +129,7 @@ enum DecodeNode<J: Jet> { Fail(FailEntropy), Hidden(Cmr), Jet(J), - Word(Arc<Value>), + Word(Value), } impl<'d, J: Jet> DagLike for (usize, &'d [DecodeNode<J>]) { @@ -243,7 +243,9 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>( Hidden(cmr) } DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)), - DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))), + DecodeNode::Word(ref w) => { + Node(ArcNode::const_word(&inference_context, w.shallow_clone())) + } }; converted.push(new); } @@ -326,9 +328,9 @@ fn decode_node<I: Iterator<Item = u8>, J: Jet>( pub fn decode_power_of_2<I: Iterator<Item = bool>>( iter: &mut I, exp: usize, -) -> Result<Arc<Value>, Error> { +) -> Result<Value, Error> { struct StackElem { - value: Arc<Value>, + value: Value, width: usize, } diff --git a/src/bit_encoding/encode.rs b/src/bit_encoding/encode.rs index daeec0b8..9469fdbc 100644 --- a/src/bit_encoding/encode.rs +++ b/src/bit_encoding/encode.rs @@ -259,8 +259,8 @@ fn encode_node<W: io::Write, N: node::Marker>( node::Inner::Word(val) => { w.write_bit(true)?; // jet or word w.write_bit(false)?; // word - assert_eq!(val.len().count_ones(), 1); - let depth = val.len().trailing_zeros(); + assert_eq!(val.compact_len().count_ones(), 1); + let depth = val.compact_len().trailing_zeros(); encode_natural(1 + depth as usize, w)?; encode_value(val, w)?; } @@ -289,20 +289,15 @@ where pub fn encode_value<W: io::Write>(value: &Value, w: &mut BitWriter<W>) -> io::Result<usize> { let n_start = w.n_total_written(); - match value { - Value::Unit => {} - Value::Left(left) => { - w.write_bit(false)?; - encode_value(left, w)?; - } - Value::Right(right) => { - w.write_bit(true)?; - encode_value(right, w)?; - } - Value::Product(left, right) => { - encode_value(left, w)?; - encode_value(right, w)?; - } + if let Some(left) = value.as_left() { + w.write_bit(false)?; + encode_value(left, w)?; + } else if let Some(right) = value.as_right() { + w.write_bit(true)?; + encode_value(right, w)?; + } else if let Some((left, right)) = value.as_product() { + encode_value(left, w)?; + encode_value(right, w)?; } Ok(w.n_total_written() - n_start) diff --git a/src/bit_encoding/mod.rs b/src/bit_encoding/mod.rs index 9b3eef1a..6463933d 100644 --- a/src/bit_encoding/mod.rs +++ b/src/bit_encoding/mod.rs @@ -13,5 +13,5 @@ mod bitwriter; pub mod decode; pub mod encode; -pub use bititer::{u2, BitIter, CloseError, EarlyEndOfStreamError}; +pub use bititer::{u2, BitCollector, BitIter, CloseError, EarlyEndOfStreamError}; pub use bitwriter::{write_to_vec, BitWriter}; diff --git a/src/bit_machine/mod.rs b/src/bit_machine/mod.rs index a3d81598..8b56873c 100644 --- a/src/bit_machine/mod.rs +++ b/src/bit_machine/mod.rs @@ -13,7 +13,6 @@ use std::fmt; use std::sync::Arc; use crate::analysis; -use crate::dag::{DagLike, NoSharing}; use crate::jet::{Jet, JetFailed}; use crate::node::{self, RedeemNode}; use crate::types::Final; @@ -53,7 +52,7 @@ impl BitMachine { pub fn test_exec<J: Jet>( program: Arc<crate::node::ConstructNode<J>>, env: &J::Environment, - ) -> Result<Arc<Value>, ExecutionError> { + ) -> Result<Value, ExecutionError> { use crate::node::SimpleFinalizer; let prog = program @@ -172,13 +171,8 @@ impl BitMachine { /// Write a value to the current write frame fn write_value(&mut self, val: &Value) { - for val in 
val.pre_order_iter::<NoSharing>() { - match val { - Value::Unit => {} - Value::Left(..) => self.write_bit(false), - Value::Right(..) => self.write_bit(true), - Value::Product(..) => {} - } + for bit in val.iter_padded() { + self.write_bit(bit); } } @@ -203,7 +197,7 @@ impl BitMachine { } // Unit value doesn't need extra frame if !input.is_empty() { - self.new_frame(input.len()); + self.new_frame(input.padded_len()); self.write_value(input); self.move_frame(); } @@ -217,7 +211,7 @@ impl BitMachine { &mut self, program: &RedeemNode<J>, env: &J::Environment, - ) -> Result<Arc<Value>, ExecutionError> { + ) -> Result<Value, ExecutionError> { enum CallStack<'a, J: Jet> { Goto(&'a RedeemNode<J>), MoveFrame, @@ -534,7 +528,7 @@ mod tests { cmr_str: &str, amr_str: &str, imr_str: &str, - ) -> Result<Arc<Value>, ExecutionError> { + ) -> Result<Value, ExecutionError> { let prog_hex = prog_bytes.as_hex(); let prog = BitIter::from(prog_bytes); diff --git a/src/human_encoding/mod.rs b/src/human_encoding/mod.rs index c1bfe31b..38140082 100644 --- a/src/human_encoding/mod.rs +++ b/src/human_encoding/mod.rs @@ -213,7 +213,7 @@ impl<J: Jet> Forest<J> { /// Succeeds if the forest contains a "main" root and returns `None` otherwise. pub fn to_witness_node( &self, - witness: &HashMap<Arc<str>, Arc<Value>>, + witness: &HashMap<Arc<str>, Value>, ) -> Option<Arc<WitnessNode<J>>> { let main = self.roots.get("main")?; Some(main.to_witness_node(witness, self.roots())) @@ -230,7 +230,7 @@ mod tests { fn assert_finalize_ok<J: Jet>( s: &str, - witness: &HashMap<Arc<str>, Arc<Value>>, + witness: &HashMap<Arc<str>, Value>, env: &J::Environment, ) { let program = Forest::<J>::parse(s) @@ -245,7 +245,7 @@ mod tests { fn assert_finalize_err<J: Jet>( s: &str, - witness: &HashMap<Arc<str>, Arc<Value>>, + witness: &HashMap<Arc<str>, Value>, env: &J::Environment, err_msg: &'static str, ) { diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index 8aeccec2..ccb6ac2d 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -110,11 +110,11 @@ impl<J: Jet> NamedCommitNode<J> { pub fn to_witness_node( &self, - witness: &HashMap<Arc<str>, Arc<Value>>, + witness: &HashMap<Arc<str>, Value>, disconnect: &HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, ) -> Arc<WitnessNode<J>> { struct Populator<'a, J: Jet> { - witness_map: &'a HashMap<Arc<str>, Arc<Value>>, + witness_map: &'a HashMap<Arc<str>, Value>, disconnect_map: &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, inference_context: types::Context, phantom: PhantomData<J>, @@ -127,7 +127,7 @@ impl<J: Jet> NamedCommitNode<J> { &mut self, data: &PostOrderIterItem<&Node<Named<Commit<J>>>>, _: &NoWitness, - ) -> Result<Option<Arc<Value>>, Self::Error> { + ) -> Result<Option<Value>, Self::Error> { let name = &data.node.cached_data().name; // We keep the witness nodes without data unpopulated. // Some nodes are pruned later so they don't need to be populated. 
@@ -175,7 +175,7 @@ impl<J: Jet> NamedCommitNode<J> { &Arc<Node<Witness<J>>>, J, &Option<Arc<WitnessNode<J>>>, - &Option<Arc<Value>>, + &Option<Value>, >, ) -> Result<WitnessData<J>, Self::Error> { let inner = inner diff --git a/src/human_encoding/parse/mod.rs b/src/human_encoding/parse/mod.rs index cef245d0..4ff88097 100644 --- a/src/human_encoding/parse/mod.rs +++ b/src/human_encoding/parse/mod.rs @@ -580,7 +580,7 @@ mod tests { fn assert_cmr_witness<J: Jet>( s: &str, cmr: &str, - witness: &HashMap<Arc<str>, Arc<Value>>, + witness: &HashMap<Arc<str>, Value>, env: &J::Environment, ) { match parse::<J>(s) { @@ -619,7 +619,7 @@ mod tests { } } - fn assert_const<J: Jet>(s: &str, value: Arc<Value>) { + fn assert_const<J: Jet>(s: &str, value: Value) { match parse::<J>(s) { Ok(forest) => { assert_eq!(forest.len(), 1); diff --git a/src/human_encoding/serialize.rs b/src/human_encoding/serialize.rs index 653208e2..adba2441 100644 --- a/src/human_encoding/serialize.rs +++ b/src/human_encoding/serialize.rs @@ -2,12 +2,10 @@ //! Serialization +use crate::bit_encoding::BitCollector; use hex::DisplayHex; use std::fmt; -use crate::dag::{DagLike, NoSharing}; -use crate::Value; - pub struct DisplayWord<'a>(pub &'a crate::Value); impl<'a> fmt::Display for DisplayWord<'a> { @@ -15,15 +13,14 @@ impl<'a> fmt::Display for DisplayWord<'a> { // The default value serialization shows the whole structure of // the value; but for words, the structure is always fixed by the // length, so it is fine to just serialize the bits. - if let Ok(hex) = self.0.try_to_bytes() { + if let Ok(hex) = self.0.iter_compact().try_collect_bytes() { write!(f, "0x{}", hex.as_hex())?; } else { f.write_str("0b")?; - for comb in self.0.pre_order_iter::<NoSharing>() { - match comb { - Value::Left(..) => f.write_str("0")?, - Value::Right(..) => f.write_str("1")?, - _ => {} + for bit in self.0.iter_compact() { + match bit { + false => f.write_str("0")?, + true => f.write_str("1")?, } } } diff --git a/src/merkle/cmr.rs b/src/merkle/cmr.rs index 7e088f5a..6e437c07 100644 --- a/src/merkle/cmr.rs +++ b/src/merkle/cmr.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: CC0-1.0 -use std::sync::Arc; - use crate::impl_midstate_wrapper; use crate::jet::Jet; use crate::node::{ @@ -103,13 +101,12 @@ impl Cmr { /// This is equal to the IMR of the equivalent scribe, converted to a CMR in /// the usual way for jets. pub fn const_word(v: &Value) -> Self { - assert_eq!(v.len().count_ones(), 1); - let w = 1 + v.len().trailing_zeros() as usize; + assert_eq!(v.compact_len().count_ones(), 1); + let w = 1 + v.compact_len().trailing_zeros() as usize; let mut cmr_stack = Vec::with_capacity(33); // 1. Compute the CMR for the `scribe` corresponding to this word jet - let mut bit_idx = 0; - v.do_each_bit(|bit| { + for (bit_idx, bit) in v.iter_compact().enumerate() { cmr_stack.push(Cmr::BITS[usize::from(bit)]); let mut j = bit_idx; while j & 1 == 1 { @@ -118,9 +115,7 @@ impl Cmr { cmr_stack.push(Cmr::PAIR_IV.update(left_cmr, right_cmr)); j >>= 1; } - - bit_idx += 1; - }); + } assert_eq!(cmr_stack.len(), 1); let imr_iv = Self::CONST_WORD_IV; @@ -128,7 +123,8 @@ impl Cmr { // 2. Add TMRs to get the pass-two IMR let imr_pass2 = imr_pass1.update(Tmr::unit().into(), Tmr::POWERS_OF_TWO[w - 1].into()); // 3. 
Convert to a jet CMR - Cmr(bip340_iv(b"Simplicity-Draft\x1fJet")).update_with_weight(v.len() as u64, imr_pass2) + Cmr(bip340_iv(b"Simplicity-Draft\x1fJet")) + .update_with_weight(v.compact_len() as u64, imr_pass2) } #[rustfmt::skip] @@ -349,7 +345,7 @@ impl CoreConstructible for ConstructibleCmr { } } - fn const_word(inference_context: &types::Context, word: Arc<Value>) -> Self { + fn const_word(inference_context: &types::Context, word: Value) -> Self { ConstructibleCmr { cmr: Cmr::const_word(&word), inference_context: inference_context.shallow_clone(), @@ -422,7 +418,7 @@ mod tests { #[test] fn fixed_const_word_cmr() { // Checked against C implementation - let bit0 = Value::left(Value::unit()); + let bit0 = Value::u1(0); #[rustfmt::skip] assert_eq!( Cmr::const_word(&bit0), diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 85492ae3..58a6f746 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -10,6 +10,7 @@ pub mod cmr; pub mod imr; pub mod tmr; +use crate::bit_encoding::BitCollector; use crate::Value; use hashes::{sha256, Hash, HashEngine}; use std::fmt; @@ -43,7 +44,7 @@ impl AsRef<[u8]> for FailEntropy { /// Helper function to compute the "compact value", i.e. the sha256 hash /// of the bits of a given value, which is used in some IMRs and AMRs. fn compact_value(value: &Value) -> [u8; 32] { - let (mut bytes, bit_length) = value.to_bytes_len(); + let (mut bytes, bit_length) = value.iter_compact().collect_bits(); // TODO: Automate hashing once `hashes` supports bit-wise hashing // 1.1 Append single '1' bit diff --git a/src/node/construct.rs b/src/node/construct.rs index e2fd6db6..9bd2e4e7 100644 --- a/src/node/construct.rs +++ b/src/node/construct.rs @@ -256,7 +256,7 @@ impl<J> CoreConstructible for ConstructData<J> { } } - fn const_word(inference_context: &types::Context, word: Arc<Value>) -> Self { + fn const_word(inference_context: &types::Context, word: Value) -> Self { ConstructData { arrow: Arrow::const_word(inference_context, word), phantom: PhantomData, @@ -390,7 +390,7 @@ mod tests { assert_eq!( unit.cmr(), - Arc::<ConstructNode<Core>>::scribe(&ctx, &Value::Unit).cmr() + Arc::<ConstructNode<Core>>::scribe(&ctx, &Value::unit()).cmr() ); assert_eq!( bit0.cmr(), diff --git a/src/node/convert.rs b/src/node/convert.rs index 19c63b23..ff795d64 100644 --- a/src/node/convert.rs +++ b/src/node/convert.rs @@ -154,26 +154,24 @@ pub trait Converter<N: Marker, M: Marker> { /// If it encounters a disconnect node, it simply returns an error. // FIXME we should do type checking, but this would require a method to check // type compatibility between a Value and a type::Final. 
-pub struct SimpleFinalizer<W: Iterator<Item = Arc<Value>>> { +pub struct SimpleFinalizer<W: Iterator<Item = Value>> { iter: W, } -impl<W: Iterator<Item = Arc<Value>>> SimpleFinalizer<W> { +impl<W: Iterator<Item = Value>> SimpleFinalizer<W> { pub fn new(iter: W) -> Self { SimpleFinalizer { iter } } } -impl<W: Iterator<Item = Arc<Value>>, J: Jet> Converter<Commit<J>, Redeem<J>> - for SimpleFinalizer<W> -{ +impl<W: Iterator<Item = Value>, J: Jet> Converter<Commit<J>, Redeem<J>> for SimpleFinalizer<W> { type Error = crate::Error; fn convert_witness( &mut self, _: &PostOrderIterItem<&CommitNode<J>>, _: &NoWitness, - ) -> Result<Arc<Value>, Self::Error> { + ) -> Result<Value, Self::Error> { self.iter.next().ok_or(crate::Error::NoMoreWitnesses) } @@ -189,12 +187,12 @@ impl<W: Iterator<Item = Arc<Value>>, J: Jet> Converter<Commit<J>, Redeem<J>> fn convert_data( &mut self, data: &PostOrderIterItem<&CommitNode<J>>, - inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Arc<Value>>, + inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Value>, ) -> Result<Arc<RedeemData<J>>, Self::Error> { let converted_data = inner .map(|node| node.cached_data()) .map_disconnect(|node| node.cached_data()) - .map_witness(Arc::clone); + .map_witness(Value::shallow_clone); Ok(Arc::new(RedeemData::new( data.node.arrow().shallow_clone(), converted_data, diff --git a/src/node/inner.rs b/src/node/inner.rs index 1e52c660..1da282da 100644 --- a/src/node/inner.rs +++ b/src/node/inner.rs @@ -44,7 +44,7 @@ pub enum Inner<C, J, X, W> { /// Application jet Jet(J), /// Constant word - Word(Arc<Value>), + Word(Value), } impl<C, J: Clone, X, W> Inner<C, J, X, W> { @@ -144,7 +144,7 @@ impl<C, J: Clone, X, W> Inner<C, J, X, W> { Inner::Witness(w) => Inner::Witness(w), Inner::Fail(entropy) => Inner::Fail(*entropy), Inner::Jet(j) => Inner::Jet(j.clone()), - Inner::Word(w) => Inner::Word(Arc::clone(w)), + Inner::Word(w) => Inner::Word(w.shallow_clone()), } } @@ -171,7 +171,7 @@ impl<C, J: Clone, X, W> Inner<C, J, X, W> { Inner::Witness(w) => Inner::Witness(w), Inner::Fail(entropy) => Inner::Fail(entropy), Inner::Jet(j) => Inner::Jet(j), - Inner::Word(ref w) => Inner::Word(Arc::clone(w)), + Inner::Word(ref w) => Inner::Word(w.shallow_clone()), } } diff --git a/src/node/mod.rs b/src/node/mod.rs index 056347e6..84a8643d 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -148,7 +148,7 @@ pub trait Constructible<J, X, W>: Inner::Pair(left, right) => Self::pair(left, right), Inner::Disconnect(left, right) => Self::disconnect(left, right), Inner::Fail(entropy) => Ok(Self::fail(inference_context, entropy)), - Inner::Word(ref w) => Ok(Self::const_word(inference_context, Arc::clone(w))), + Inner::Word(ref w) => Ok(Self::const_word(inference_context, w.shallow_clone())), Inner::Jet(j) => Ok(Self::jet(inference_context, j)), Inner::Witness(w) => Ok(Self::witness(inference_context, w)), } @@ -177,7 +177,7 @@ pub trait CoreConstructible: Sized { fn assertr(left: Cmr, right: &Self) -> Result<Self, types::Error>; fn pair(left: &Self, right: &Self) -> Result<Self, types::Error>; fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self; - fn const_word(inference_context: &types::Context, word: Arc<Value>) -> Self; + fn const_word(inference_context: &types::Context, word: Value) -> Self; /// Accessor for the type inference context used to create the object. 
fn inference_context(&self) -> &types::Context; @@ -188,23 +188,18 @@ pub trait CoreConstructible: Sized { fn scribe(inference_context: &types::Context, value: &Value) -> Self { let mut stack = vec![]; for data in value.post_order_iter::<NoSharing>() { - match data.node { - Value::Unit => stack.push(Self::unit(inference_context)), - Value::Left(..) => { - let child = stack.pop().unwrap(); - stack.push(Self::injl(&child)); - } - Value::Right(..) => { - let child = stack.pop().unwrap(); - stack.push(Self::injr(&child)); - } - Value::Product(..) => { - let right = stack.pop().unwrap(); - let left = stack.pop().unwrap(); - stack.push( - Self::pair(&left, &right).expect("source of scribe has no constraints"), - ); - } + if data.node.is_unit() { + stack.push(Self::unit(inference_context)); + } else if data.node.as_left().is_some() { + let child = stack.pop().unwrap(); + stack.push(Self::injl(&child)); + } else if data.node.as_right().is_some() { + let child = stack.pop().unwrap(); + stack.push(Self::injr(&child)); + } else if data.node.as_product().is_some() { + let right = stack.pop().unwrap(); + let left = stack.pop().unwrap(); + stack.push(Self::pair(&left, &right).expect("source of scribe has no constraints")); } } assert_eq!(stack.len(), 1); @@ -479,10 +474,10 @@ where }) } - fn const_word(inference_context: &types::Context, value: Arc<Value>) -> Self { + fn const_word(inference_context: &types::Context, value: Value) -> Self { Arc::new(Node { cmr: Cmr::const_word(&value), - data: N::CachedData::const_word(inference_context, Arc::clone(&value)), + data: N::CachedData::const_word(inference_context, value.shallow_clone()), inner: Inner::Word(value), }) } diff --git a/src/node/redeem.rs b/src/node/redeem.rs index ce63bf42..30880419 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -27,7 +27,7 @@ pub struct Redeem<J> { impl<J: Jet> Marker for Redeem<J> { type CachedData = Arc<RedeemData<J>>; - type Witness = Arc<Value>; + type Witness = Value; type Disconnect = Arc<RedeemNode<J>>; type SharingId = Imr; type Jet = J; @@ -66,7 +66,7 @@ impl<J> std::hash::Hash for RedeemData<J> { } impl<J: Jet> RedeemData<J> { - pub fn new(arrow: FinalArrow, inner: Inner<&Arc<Self>, J, &Arc<Self>, Arc<Value>>) -> Self { + pub fn new(arrow: FinalArrow, inner: Inner<&Arc<Self>, J, &Arc<Self>, Value>) -> Self { let (amr, first_pass_imr, bounds) = match inner { Inner::Iden => ( Amr::iden(&arrow), @@ -190,7 +190,7 @@ impl<J: Jet> RedeemNode<J> { fn convert_witness( &mut self, _: &PostOrderIterItem<&RedeemNode<J>>, - _: &Arc<Value>, + _: &Value, ) -> Result<NoWitness, Self::Error> { Ok(NoWitness) } @@ -234,8 +234,8 @@ impl<J: Jet> RedeemNode<J> { fn convert_witness( &mut self, _: &PostOrderIterItem<&Node<Redeem<J>>>, - witness: &Arc<Value>, - ) -> Result<Option<Arc<Value>>, Self::Error> { + witness: &Value, + ) -> Result<Option<Value>, Self::Error> { Ok(Some(witness.clone())) } @@ -255,7 +255,7 @@ impl<J: Jet> RedeemNode<J> { &Arc<Node<Witness<J>>>, J, &Option<Arc<WitnessNode<J>>>, - &Option<Arc<Value>>, + &Option<Value>, >, ) -> Result<WitnessData<J>, Self::Error> { let inner = inner @@ -296,7 +296,7 @@ impl<J: Jet> RedeemNode<J> { &mut self, data: &PostOrderIterItem<&ConstructNode<J>>, _: &NoWitness, - ) -> Result<Arc<Value>, Self::Error> { + ) -> Result<Value, Self::Error> { let arrow = data.node.data.arrow(); let target_ty = arrow.target.finalize()?; self.bits.read_value(&target_ty).map_err(Error::from) @@ -318,13 +318,13 @@ impl<J: Jet> RedeemNode<J> { fn convert_data( &mut self, data: 
&PostOrderIterItem<&ConstructNode<J>>, - inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Arc<Value>>, + inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Value>, ) -> Result<Arc<RedeemData<J>>, Self::Error> { let arrow = data.node.data.arrow().finalize()?; let converted_data = inner .map(|node| node.cached_data()) .map_disconnect(|node| node.cached_data()) - .map_witness(Arc::clone); + .map_witness(Value::shallow_clone); Ok(Arc::new(RedeemData::new(arrow, converted_data))) } } @@ -379,8 +379,7 @@ impl<J: Jet> RedeemNode<J> { let sharing_iter = self.post_order_iter::<MaxSharing<Redeem<J>>>(); let program_bits = encode::encode_program(self, prog)?; prog.flush_all()?; - let witness_bits = - encode::encode_witness(sharing_iter.into_witnesses().map(Arc::as_ref), witness)?; + let witness_bits = encode::encode_witness(sharing_iter.into_witnesses(), witness)?; witness.flush_all()?; Ok(program_bits + witness_bits) } diff --git a/src/node/witness.rs b/src/node/witness.rs index b9748c58..fc9395fd 100644 --- a/src/node/witness.rs +++ b/src/node/witness.rs @@ -31,7 +31,7 @@ pub struct Witness<J> { impl<J: Jet> Marker for Witness<J> { type CachedData = WitnessData<J>; - type Witness = Option<Arc<Value>>; + type Witness = Option<Value>; type Disconnect = Option<Arc<WitnessNode<J>>>; type SharingId = WitnessId; type Jet = J; @@ -59,7 +59,7 @@ impl<J: Jet> WitnessNode<J> { .as_ref() .map(Arc::clone) .map_disconnect(Option::<Arc<_>>::clone) - .map_witness(Option::<Arc<Value>>::clone), + .map_witness(Option::<Value>::clone), }) } @@ -84,9 +84,9 @@ impl<J: Jet> WitnessNode<J> { fn convert_witness( &mut self, _: &PostOrderIterItem<&WitnessNode<J>>, - wit: &Option<Arc<Value>>, - ) -> Result<Option<Arc<Value>>, Self::Error> { - Ok(Option::<Arc<Value>>::clone(wit)) + wit: &Option<Value>, + ) -> Result<Option<Value>, Self::Error> { + Ok(Option::<Value>::clone(wit)) } fn prune_case( @@ -123,16 +123,11 @@ impl<J: Jet> WitnessNode<J> { fn convert_data( &mut self, data: &PostOrderIterItem<&WitnessNode<J>>, - inner: Inner< - &Arc<WitnessNode<J>>, - J, - &Option<Arc<WitnessNode<J>>>, - &Option<Arc<Value>>, - >, + inner: Inner<&Arc<WitnessNode<J>>, J, &Option<Arc<WitnessNode<J>>>, &Option<Value>>, ) -> Result<WitnessData<J>, Self::Error> { let converted_inner = inner .map(|node| node.cached_data()) - .map_witness(Option::<Arc<Value>>::clone); + .map_witness(Option::<Value>::clone); // This next line does the actual retyping. 
let mut retyped = WitnessData::from_inner(&self.inference_context, converted_inner)?; @@ -164,10 +159,10 @@ impl<J: Jet> WitnessNode<J> { fn convert_witness( &mut self, _: &PostOrderIterItem<&WitnessNode<J>>, - wit: &Option<Arc<Value>>, - ) -> Result<Arc<Value>, Self::Error> { + wit: &Option<Value>, + ) -> Result<Value, Self::Error> { if let Some(ref wit) = wit { - Ok(Arc::clone(wit)) + Ok(wit.shallow_clone()) } else { Err(Error::IncompleteFinalization) } @@ -189,12 +184,12 @@ impl<J: Jet> WitnessNode<J> { fn convert_data( &mut self, data: &PostOrderIterItem<&WitnessNode<J>>, - inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Arc<Value>>, + inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Value>, ) -> Result<Arc<RedeemData<J>>, Self::Error> { let converted_data = inner .map(|node| node.cached_data()) .map_disconnect(|node| node.cached_data()) - .map_witness(Arc::clone); + .map_witness(Value::shallow_clone); Ok(Arc::new(RedeemData::new( data.node.arrow().finalize()?, converted_data, @@ -338,7 +333,7 @@ impl<J> CoreConstructible for WitnessData<J> { } } - fn const_word(inference_context: &types::Context, word: Arc<Value>) -> Self { + fn const_word(inference_context: &types::Context, word: Value) -> Self { WitnessData { arrow: Arrow::const_word(inference_context, word), must_prune: false, @@ -362,8 +357,8 @@ impl<J: Jet> DisconnectConstructible<Option<Arc<WitnessNode<J>>>> for WitnessDat } } -impl<J> WitnessConstructible<Option<Arc<Value>>> for WitnessData<J> { - fn witness(inference_context: &types::Context, witness: Option<Arc<Value>>) -> Self { +impl<J> WitnessConstructible<Option<Value>> for WitnessData<J> { + fn witness(inference_context: &types::Context, witness: Option<Value>) -> Self { WitnessData { arrow: Arrow::witness(inference_context, NoWitness), must_prune: witness.is_none(), diff --git a/src/policy/satisfy.rs b/src/policy/satisfy.rs index ddb85fa8..53c4aed6 100644 --- a/src/policy/satisfy.rs +++ b/src/policy/satisfy.rs @@ -193,7 +193,7 @@ impl<Pk: ToXOnlyPubkey> Policy<Pk> { // this manually. threshold_failed = true; } - witness_bits[idx] = Some(Arc::clone(&b1)); + witness_bits[idx] = Some(b1.shallow_clone()); } for &(idx, _) in &sorted_costs[k..] 
{ nodes[idx] = nodes[idx].pruned(); @@ -230,6 +230,7 @@ impl<Pk: ToXOnlyPubkey> Policy<Pk> { #[cfg(test)] mod tests { use super::*; + use crate::bit_encoding::BitCollector; use crate::dag::{DagLike, NoSharing}; use crate::jet::elements::ElementsEnv; use crate::node::{CoreConstructible, JetConstructible, SimpleFinalizer, WitnessConstructible}; @@ -329,7 +330,7 @@ mod tests { assert!(mac.exec(&program, env).is_err()); } - fn to_witness(program: &RedeemNode<Elements>) -> Vec<&Arc<Value>> { + fn to_witness(program: &RedeemNode<Elements>) -> Vec<&Value> { program .post_order_iter::<NoSharing>() .into_witnesses() @@ -379,7 +380,10 @@ mod tests { let sighash = env.c_tx_env().sighash_all(); let message = secp256k1_zkp::Message::from(sighash); - let signature_bytes = witness[0].try_to_bytes().expect("to bytes"); + let signature_bytes = witness[0] + .iter_padded() + .try_collect_bytes() + .expect("to bytes"); let signature = secp256k1_zkp::schnorr::Signature::from_slice(&signature_bytes).expect("to signature"); assert!(signature.verify(&message, xonly).is_ok()); @@ -399,7 +403,10 @@ mod tests { let witness = to_witness(&program); assert_eq!(1, witness.len()); - let witness_bytes = witness[0].try_to_bytes().expect("to bytes"); + let witness_bytes = witness[0] + .iter_padded() + .try_collect_bytes() + .expect("to bytes"); let witness_preimage = Preimage32::try_from(witness_bytes.as_slice()).expect("to array"); let preimage = *satisfier.preimages.get(&image).unwrap(); assert_eq!(preimage, witness_preimage); @@ -475,7 +482,10 @@ mod tests { assert_eq!(2, witness.len()); for i in 0..2 { - let preimage_bytes = witness[i].try_to_bytes().expect("to bytes"); + let preimage_bytes = witness[i] + .iter_padded() + .try_collect_bytes() + .expect("to bytes"); let witness_preimage = Preimage32::try_from(preimage_bytes.as_slice()).expect("to array"); assert_eq!(preimages[i], &witness_preimage); @@ -516,7 +526,10 @@ mod tests { assert_eq!(2, witness.len()); assert_eq!(Value::u1(bit as u8), *witness[0]); - let preimage_bytes = witness[1].try_to_bytes().expect("to bytes"); + let preimage_bytes = witness[1] + .iter_padded() + .try_collect_bytes() + .expect("to bytes"); let witness_preimage = Preimage32::try_from(preimage_bytes.as_slice()).expect("to array"); assert_eq!(preimages[bit as usize], &witness_preimage); @@ -580,7 +593,10 @@ mod tests { assert_eq!(*witness[witidx], Value::u1(bit.into())); witidx += 1; if bit { - let preimage_bytes = witness[witidx].try_to_bytes().expect("to bytes"); + let preimage_bytes = witness[witidx] + .iter_padded() + .try_collect_bytes() + .expect("to bytes"); let witness_preimage = Preimage32::try_from(preimage_bytes.as_slice()).expect("to array"); assert_eq!(preimages[bit_n], &witness_preimage); @@ -633,7 +649,7 @@ mod tests { let env = ElementsEnv::dummy(); let mut satisfier = get_satisfier(&env); - let mut assert_branch = |witness0: Arc<Value>, witness1: Arc<Value>| { + let mut assert_branch = |witness0: Value, witness1: Value| { let asm_program = serialize::verify_bexp( &Arc::<WitnessNode<Elements>>::pair( &Arc::<WitnessNode<Elements>>::witness(&ctx, Some(witness0.clone())), diff --git a/src/policy/serialize.rs b/src/policy/serialize.rs index ee726455..b324c772 100644 --- a/src/policy/serialize.rs +++ b/src/policy/serialize.rs @@ -281,7 +281,7 @@ mod tests { fn execute_successful( commit: &CommitNode<Elements>, - witness: Vec<Arc<Value>>, + witness: Vec<Value>, env: &ElementsEnv<Arc<elements::Transaction>>, ) -> bool { let finalized = commit diff --git a/src/types/arrow.rs 
b/src/types/arrow.rs index 8fbee5f6..c9faca16 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -310,11 +310,11 @@ impl CoreConstructible for Arrow { } } - fn const_word(inference_context: &Context, word: Arc<Value>) -> Self { - let len = word.len(); + fn const_word(inference_context: &Context, word: Value) -> Self { + let len = word.compact_len(); assert!(len > 0, "Words must not be the empty bitstring"); assert!(len.is_power_of_two()); - let depth = word.len().trailing_zeros(); + let depth = len.trailing_zeros(); Arrow { source: Type::unit(inference_context), target: Type::two_two_n(inference_context, depth as usize), diff --git a/src/value.rs b/src/value.rs index b2a128b9..dc17b034 100644 --- a/src/value.rs +++ b/src/value.rs @@ -14,8 +14,15 @@ use std::hash::Hash; use std::sync::Arc; /// A Simplicity value. +#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct Value { + inner: ValueInner, + ty: Arc<Final>, +} + +/// The inner structure of a Simplicity value. #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Value { +enum ValueInner { /// The unit value. /// /// The unit value is the only value of the unit type `1`. @@ -47,145 +54,190 @@ impl<'a> DagLike for &'a Value { } fn as_dag_node(&self) -> Dag<Self> { - match self { - Value::Unit => Dag::Nullary, - Value::Left(child) | Value::Right(child) => Dag::Unary(child), - Value::Product(left, right) => Dag::Binary(left, right), + match &self.inner { + ValueInner::Unit => Dag::Nullary, + ValueInner::Left(child) | ValueInner::Right(child) => Dag::Unary(child), + ValueInner::Product(left, right) => Dag::Binary(left, right), } } } impl Value { + /// Make a cheap copy of the value. + pub fn shallow_clone(&self) -> Self { + Self { + inner: self.inner.clone(), + ty: Arc::clone(&self.ty), + } + } + + /// Access the type of the value. + pub fn ty(&self) -> &Final { + &self.ty + } + /// Create the unit value. - pub fn unit() -> Arc<Self> { - Arc::new(Self::Unit) + pub fn unit() -> Self { + Self { + inner: ValueInner::Unit, + ty: Final::unit(), + } } /// Create a left value that wraps the given `inner` value. - pub fn left(inner: Arc<Self>) -> Arc<Self> { - Arc::new(Value::Left(inner)) + pub fn left(inner: Self, right: Arc<Final>) -> Self { + Self { + ty: Final::sum(Arc::clone(&inner.ty), right), + inner: ValueInner::Left(Arc::new(inner)), + } } /// Create a right value that wraps the given `inner` value. - pub fn right(inner: Arc<Self>) -> Arc<Self> { - Arc::new(Value::Right(inner)) + pub fn right(left: Arc<Final>, inner: Self) -> Self { + Self { + ty: Final::sum(left, Arc::clone(&inner.ty)), + inner: ValueInner::Right(Arc::new(inner)), + } } /// Create a product value that wraps the given `left` and `right` values. - pub fn product(left: Arc<Self>, right: Arc<Self>) -> Arc<Self> { - Arc::new(Value::Product(left, right)) + pub fn product(left: Self, right: Self) -> Self { + Self { + ty: Final::product(Arc::clone(&left.ty), Arc::clone(&right.ty)), + inner: ValueInner::Product(Arc::new(left), Arc::new(right)), + } } - /// The length, in bits, of the value when encoded in the Bit Machine - pub fn len(&self) -> usize { - self.pre_order_iter::<NoSharing>() - .filter(|inner| matches!(inner, Value::Left(_) | Value::Right(_))) - .count() + /// Create a none value. + pub fn none(right: Arc<Final>) -> Self { + Self { + ty: Final::sum(Final::unit(), right), + inner: ValueInner::Left(Arc::new(Value::unit())), + } + } + + /// Create a some value. 
+ pub fn some(inner: Self) -> Self { + Self { + ty: Final::sum(Final::unit(), Arc::clone(&inner.ty)), + inner: ValueInner::Right(Arc::new(inner)), + } + } + + /// Return the bit length of the value in compact encoding. + pub fn compact_len(&self) -> usize { + self.iter_compact().count() + } + + /// Return the bit length of the value in padded encoding. + pub fn padded_len(&self) -> usize { + self.iter_padded().count() } /// Check if the value is a nested product of units. /// In this case, the value contains no information. pub fn is_empty(&self) -> bool { - self.len() == 0 + self.pre_order_iter::<NoSharing>() + .all(|value| matches!(&value.inner, ValueInner::Unit | ValueInner::Product(..))) } /// Check if the value is a unit. pub fn is_unit(&self) -> bool { - matches!(self, Value::Unit) + matches!(&self.inner, ValueInner::Unit) } /// Access the inner value of a left sum value. pub fn as_left(&self) -> Option<&Self> { - match self { - Value::Left(inner) => Some(inner.as_ref()), + match &self.inner { + ValueInner::Left(inner) => Some(inner.as_ref()), _ => None, } } /// Access the inner value of a right sum value. pub fn as_right(&self) -> Option<&Self> { - match self { - Value::Right(inner) => Some(inner.as_ref()), + match &self.inner { + ValueInner::Right(inner) => Some(inner.as_ref()), _ => None, } } /// Access the inner values of a product value. pub fn as_product(&self) -> Option<(&Self, &Self)> { - match self { - Value::Product(left, right) => Some((left.as_ref(), right.as_ref())), + match &self.inner { + ValueInner::Product(left, right) => Some((left.as_ref(), right.as_ref())), _ => None, } } /// Encode a single bit as a value. Will panic if the input is out of range - pub fn u1(n: u8) -> Arc<Self> { + pub fn u1(n: u8) -> Self { match n { - 0 => Value::left(Value::unit()), - 1 => Value::right(Value::unit()), + 0 => Self::left(Self::unit(), Final::unit()), + 1 => Self::right(Final::unit(), Self::unit()), x => panic!("{} out of range for Value::u1", x), } } /// Encode a two-bit number as a value. Will panic if the input is out of range - pub fn u2(n: u8) -> Arc<Self> { + pub fn u2(n: u8) -> Self { let b0 = (n & 2) / 2; let b1 = n & 1; assert!(n <= 3, "{} out of range for Value::u2", n); - Value::product(Value::u1(b0), Value::u1(b1)) + Self::product(Self::u1(b0), Self::u1(b1)) } /// Encode a four-bit number as a value. 
Will panic if the input is out of range - pub fn u4(n: u8) -> Arc<Self> { + pub fn u4(n: u8) -> Self { let w0 = (n & 12) / 4; let w1 = n & 3; assert!(n <= 15, "{} out of range for Value::u2", n); - Value::product(Value::u2(w0), Value::u2(w1)) + Self::product(Self::u2(w0), Self::u2(w1)) } /// Encode an eight-bit number as a value - pub fn u8(n: u8) -> Arc<Self> { + pub fn u8(n: u8) -> Self { let w0 = n >> 4; let w1 = n & 0xf; - Value::product(Value::u4(w0), Value::u4(w1)) + Self::product(Self::u4(w0), Self::u4(w1)) } /// Encode a 16-bit number as a value - pub fn u16(n: u16) -> Arc<Self> { + pub fn u16(n: u16) -> Self { let w0 = (n >> 8) as u8; let w1 = (n & 0xff) as u8; - Value::product(Value::u8(w0), Value::u8(w1)) + Self::product(Self::u8(w0), Self::u8(w1)) } /// Encode a 32-bit number as a value - pub fn u32(n: u32) -> Arc<Self> { + pub fn u32(n: u32) -> Self { let w0 = (n >> 16) as u16; let w1 = (n & 0xffff) as u16; - Value::product(Value::u16(w0), Value::u16(w1)) + Self::product(Self::u16(w0), Self::u16(w1)) } /// Encode a 64-bit number as a value - pub fn u64(n: u64) -> Arc<Self> { + pub fn u64(n: u64) -> Self { let w0 = (n >> 32) as u32; let w1 = (n & 0xffff_ffff) as u32; - Value::product(Value::u32(w0), Value::u32(w1)) + Self::product(Self::u32(w0), Self::u32(w1)) } /// Encode a 128-bit number as a value - pub fn u128(n: u128) -> Arc<Self> { + pub fn u128(n: u128) -> Self { let w0 = (n >> 64) as u64; let w1 = n as u64; // Cast safety: picking last 64 bits - Value::product(Value::u64(w0), Value::u64(w1)) + Self::product(Self::u64(w0), Self::u64(w1)) } /// Create a value from 32 bytes. - pub fn u256(bytes: [u8; 32]) -> Arc<Self> { - Value::from_byte_array(bytes) + pub fn u256(bytes: [u8; 32]) -> Self { + Self::from_byte_array(bytes) } /// Create a value from 64 bytes. - pub fn u512(bytes: [u8; 64]) -> Arc<Self> { - Value::from_byte_array(bytes) + pub fn u512(bytes: [u8; 64]) -> Self { + Self::from_byte_array(bytes) } /// Create a value from a byte array. @@ -193,7 +245,7 @@ impl Value { /// ## Panics /// /// The array length is not a power of two. - pub fn from_byte_array<const N: usize>(bytes: [u8; N]) -> Arc<Self> { + pub fn from_byte_array<const N: usize>(bytes: [u8; N]) -> Self { assert!(N.is_power_of_two(), "Array length must be a power of two"); let mut values: VecDeque<_> = bytes.into_iter().map(Value::u8).collect(); @@ -210,95 +262,28 @@ impl Value { values.into_iter().next().unwrap() } - /// Execute function `f` on each bit of the encoding of the value. - pub fn do_each_bit<F>(&self, mut f: F) - where - F: FnMut(bool), - { - for val in self.pre_order_iter::<NoSharing>() { - match val { - Value::Unit => {} - Value::Left(..) => f(false), - Value::Right(..) => f(true), - Value::Product(..) => {} - } - } - } - - /// Encode value as big-endian byte string. - /// Fails if underlying bit string has length not divisible by 8 - pub fn try_to_bytes(&self) -> Result<Vec<u8>, &'static str> { - let (bytes, bit_length) = self.to_bytes_len(); - - if bit_length % 8 == 0 { - Ok(bytes) - } else { - Err("Length of bit string that encodes this value is not divisible by 8!") - } + /// Return an iterator over the compact bit encoding of the value. + /// + /// This encoding is used for writing witness data and for computing IMRs. + pub fn iter_compact(&self) -> impl Iterator<Item = bool> + '_ { + self.pre_order_iter::<NoSharing>() + .filter_map(|value| match &value.inner { + ValueInner::Left(..) => Some(false), + ValueInner::Right(..) 
=> Some(true), + _ => None, + }) } - /// Encode value as big-endian byte string. - /// Trailing zeroes are added as padding if underlying bit string has length not divisible by 8. - /// The length of said bit string is returned as second argument - pub fn to_bytes_len(&self) -> (Vec<u8>, usize) { - let mut bytes = vec![]; - let mut unfinished_byte = Vec::with_capacity(8); - let update_bytes = |bit: bool| { - unfinished_byte.push(bit); - - if unfinished_byte.len() == 8 { - bytes.push( - unfinished_byte - .iter() - .fold(0, |acc, &b| acc * 2 + u8::from(b)), - ); - unfinished_byte.clear(); - } - }; - - self.do_each_bit(update_bytes); - let bit_length = bytes.len() * 8 + unfinished_byte.len(); - - if !unfinished_byte.is_empty() { - unfinished_byte.resize(8, false); - bytes.push( - unfinished_byte - .iter() - .fold(0, |acc, &b| acc * 2 + u8::from(b)), - ); - } - - (bytes, bit_length) + /// Return an iterator over the padded bit encoding of the value. + /// + /// This encoding is used to represent the value in the Bit Machine. + pub fn iter_padded(&self) -> impl Iterator<Item = bool> + '_ { + PaddedBitsIter::new(self) } /// Check if the value is of the given type. pub fn is_of_type(&self, ty: &Final) -> bool { - let mut stack = vec![(self, ty)]; - - while let Some((value, ty)) = stack.pop() { - if ty.is_unit() { - if !value.is_unit() { - return false; - } - } else if let Some((ty_l, ty_r)) = ty.as_sum() { - if let Some(value_l) = value.as_left() { - stack.push((value_l, ty_l)); - } else if let Some(value_r) = value.as_right() { - stack.push((value_r, ty_r)); - } else { - return false; - } - } else if let Some((ty_l, ty_r)) = ty.as_product() { - if let Some((value_l, value_r)) = value.as_product() { - stack.push((value_r, ty_r)); - stack.push((value_l, ty_l)); - } else { - return false; - } - } - } - - true + self.ty.as_ref() == ty } } @@ -311,25 +296,28 @@ impl fmt::Debug for Value { impl fmt::Display for Value { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for data in self.verbose_pre_order_iter::<NoSharing>(None) { - match data.node { - Value::Unit => { + match &data.node.inner { + ValueInner::Unit => { if data.n_children_yielded == 0 - && !matches!(data.parent, Some(Value::Left(_)) | Some(Value::Right(_))) + && !matches!( + data.parent.map(|value| &value.inner), + Some(ValueInner::Left(_)) | Some(ValueInner::Right(_)) + ) { f.write_str("ε")?; } } - Value::Left(..) => { + ValueInner::Left(..) => { if data.n_children_yielded == 0 { f.write_str("0")?; } } - Value::Right(..) => { + ValueInner::Right(..) => { if data.n_children_yielded == 0 { f.write_str("1")?; } } - Value::Product(..) => match data.n_children_yielded { + ValueInner::Product(..) => match data.n_children_yielded { 0 => f.write_str("(")?, 1 => f.write_str(",")?, 2 => f.write_str(")")?, @@ -341,6 +329,61 @@ impl fmt::Display for Value { } } +/// An iterator over the bits of the padded encoding of a [`Value`]. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +struct PaddedBitsIter<'a> { + stack: Vec<&'a Value>, + next_padding: Option<usize>, +} + +impl<'a> PaddedBitsIter<'a> { + /// Create an iterator over the bits of the padded encoding of the `value`. 
+ pub fn new(value: &'a Value) -> Self { + Self { + stack: vec![value], + next_padding: None, + } + } +} + +impl<'a> Iterator for PaddedBitsIter<'a> { + type Item = bool; + + fn next(&mut self) -> Option<Self::Item> { + match self.next_padding { + Some(0) => { + self.next_padding = None; + } + Some(n) => { + self.next_padding = Some(n - 1); + return Some(false); + } + None => {} + } + + while let Some(value) = self.stack.pop() { + if value.is_unit() { + // NOP + } else if let Some(l_value) = value.as_left() { + let (l_ty, r_ty) = value.ty.as_sum().unwrap(); + self.stack.push(l_value); + self.next_padding = Some(l_ty.pad_left(r_ty)); + return Some(false); + } else if let Some(r_value) = value.as_right() { + let (l_ty, r_ty) = value.ty.as_sum().unwrap(); + self.stack.push(r_value); + self.next_padding = Some(l_ty.pad_right(r_ty)); + return Some(true); + } else if let Some((l_value, r_value)) = value.as_product() { + self.stack.push(r_value); + self.stack.push(l_value); + } + } + + None + } +} + #[cfg(test)] mod tests { use super::*; @@ -359,10 +402,16 @@ mod tests { fn is_of_type() { let value_typename = [ (Value::unit(), TypeName(b"1")), - (Value::left(Value::unit()), TypeName(b"+11")), - (Value::right(Value::unit()), TypeName(b"+11")), - (Value::left(Value::unit()), TypeName(b"+1h")), - (Value::right(Value::unit()), TypeName(b"+h1")), + (Value::left(Value::unit(), Final::unit()), TypeName(b"+11")), + (Value::right(Final::unit(), Value::unit()), TypeName(b"+11")), + ( + Value::left(Value::unit(), Final::two_two_n(8)), + TypeName(b"+1h"), + ), + ( + Value::right(Final::two_two_n(8), Value::unit()), + TypeName(b"+h1"), + ), ( Value::product(Value::unit(), Value::unit()), TypeName(b"*11"),
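
To illustrate why the padded encoding matters for sum values on the Bit Machine, here is a minimal sketch (not part of the patch) against the API introduced above: Value::none/some, compact_len/padded_len and Final::u256 are taken from the diff, while the concrete padded width of 257 bits is an assumption derived from the Simplicity padding rule (one tag bit plus the width of the larger branch).

    use simplicity::types::Final;
    use simplicity::Value;

    fn main() {
        // An optional 256-bit hash, as used for the annex in annex_hash() above.
        let none = Value::none(Final::u256());            // left branch of 1 + 2^256
        let some = Value::some(Value::u256([0xab; 32]));  // right branch of 1 + 2^256

        // Compact encoding (witness data, IMR computation): one tag bit plus
        // the bits of the branch that is actually present.
        assert_eq!(none.compact_len(), 1);
        assert_eq!(some.compact_len(), 257);

        // Padded encoding (Bit Machine frames): both branches occupy the same
        // number of cells, so the smaller branch is followed by zero padding.
        // Assumption: 1 + max(0, 256) = 257 bits for either branch.
        assert_eq!(none.padded_len(), some.padded_len());
        assert!(none.padded_len() > none.compact_len());
    }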
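
The typed constructors also show how values get their type "for free" at construction time. A small sketch, assuming structural equality of types::Final (which is_of_type in this patch relies on); all constructor names are taken from value.rs above:

    use simplicity::types::Final;
    use simplicity::Value;

    fn main() {
        // A left value now names the type of the absent right branch,
        // and a right value names the type of the absent left branch.
        let bit0 = Value::left(Value::unit(), Final::unit());  // same as Value::u1(0)
        let bit1 = Value::right(Final::unit(), Value::unit()); // same as Value::u1(1)

        // Both are of type 2 = 1 + 1; no separate type annotation is needed.
        assert!(bit0.is_of_type(&Final::sum(Final::unit(), Final::unit())));
        assert!(bit0.ty() == bit1.ty());

        // Products combine the component types automatically.
        let pair = Value::product(Value::u32(7), Value::u256([0; 32]));
        assert!(pair.is_of_type(&Final::product(Final::two_two_n(5), Final::u256())));
    }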
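
Lastly, a sketch of the iterator-plus-BitCollector pattern that replaces do_each_bit and try_to_bytes. collect_bits and try_collect_bytes are the BitCollector methods used in merkle/mod.rs and human_encoding/serialize.rs above; the simplicity::bit_encoding::BitCollector import path mirrors the crate-internal uses in this patch and is an assumption for external callers.

    use simplicity::bit_encoding::BitCollector;
    use simplicity::Value;

    fn main() {
        let value = Value::u16(0xcafe);

        // Any Iterator<Item = bool> can be collected, so the same helper
        // serves both the compact and the padded encoding.
        let (bytes, bit_len) = value.iter_compact().collect_bits();
        assert_eq!(bit_len, 16);
        assert_eq!(bytes, vec![0xca, 0xfe]);

        // try_collect_bytes fails when the bit length is not a multiple of 8,
        // taking over the role of the old Value::try_to_bytes.
        let bytes = value
            .iter_padded()
            .try_collect_bytes()
            .expect("16 bits is byte-aligned");
        assert_eq!(bytes, vec![0xca, 0xfe]);
    }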