From 72d2c6285e05063b5f25c80d6b5c6f2f3613059f Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 21:39:45 +0000 Subject: [PATCH 1/4] types: refactor arrow module The arrow module has a ton of `for_*` methods which duplicate the methods in the Constructible traits. Most of these are purely redundant and can be deleted. A couple provide some code reuse, and are demoted to be non-`pub`. Smoe variable names changed in the redundant code, so this will likely not appear as a code-move in the diff, even though it is one. You can just skim this commit. --- src/types/arrow.rs | 232 ++++++++++++++++----------------------------- 1 file changed, 81 insertions(+), 151 deletions(-) diff --git a/src/types/arrow.rs b/src/types/arrow.rs index c18dc9e0..08ba86f7 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -74,135 +74,14 @@ impl Arrow { }) } - /// Create a unification arrow for a fresh `unit` combinator - pub fn for_unit() -> Self { - Arrow { - source: Type::free(new_name("unit_src_")), - target: Type::unit(), - } - } - - /// Create a unification arrow for a fresh `iden` combinator - pub fn for_iden() -> Self { - // Throughout this module, when two types are the same, we reuse a - // pointer to them rather than creating distinct types and unifying - // them. This theoretically could lead to more confusing errors for - // the user during type inference, but in practice type inference - // is completely opaque and there's no harm in making it moreso. - let new = Type::free(new_name("iden_src_")); - Arrow { - source: new.shallow_clone(), - target: new, - } - } - - /// Create a unification arrow for a fresh `witness` combinator - pub fn for_witness() -> Self { - Arrow { - source: Type::free(new_name("witness_src_")), - target: Type::free(new_name("witness_tgt_")), - } - } - - /// Create a unification arrow for a fresh `fail` combinator - pub fn for_fail() -> Self { - Arrow { - source: Type::free(new_name("fail_src_")), - target: Type::free(new_name("fail_tgt_")), - } - } - - /// Create a unification arrow for a fresh jet combinator - pub fn for_jet(jet: J) -> Self { - Arrow { - source: jet.source_ty().to_type(), - target: jet.target_ty().to_type(), - } - } - - /// Create a unification arrow for a fresh const-word combinator - pub fn for_const_word(word: &Value) -> Self { - let len = word.len(); - assert!(len > 0, "Words must not be the empty bitstring"); - assert!(len.is_power_of_two()); - let depth = word.len().trailing_zeros(); - Arrow { - source: Type::unit(), - target: Type::two_two_n(depth as usize), - } - } - - /// Create a unification arrow for a fresh `injl` combinator - pub fn for_injl(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - child_arrow.target.shallow_clone(), - Type::free(new_name("injl_tgt_")), - ), - } - } - - /// Create a unification arrow for a fresh `injr` combinator - pub fn for_injr(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - Type::free(new_name("injr_tgt_")), - child_arrow.target.shallow_clone(), - ), - } - } - - /// Create a unification arrow for a fresh `take` combinator - pub fn for_take(child_arrow: &Arrow) -> Self { - Arrow { - source: Type::product( - child_arrow.source.shallow_clone(), - Type::free(new_name("take_src_")), - ), - target: child_arrow.target.shallow_clone(), - } - } - - /// Create a unification arrow for a fresh `drop` combinator - pub fn for_drop(child_arrow: &Arrow) -> Self { + /// Same as [`Self::clone`] but named to make it clearer that this is cheap + pub fn shallow_clone(&self) -> Self { Arrow { - source: Type::product( - Type::free(new_name("drop_src_")), - child_arrow.source.shallow_clone(), - ), - target: child_arrow.target.shallow_clone(), + source: self.source.shallow_clone(), + target: self.target.shallow_clone(), } } - /// Create a unification arrow for a fresh `pair` combinator - pub fn for_pair(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.source.unify( - &rchild_arrow.source, - "pair combinator: left source = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: Type::product( - lchild_arrow.target.shallow_clone(), - rchild_arrow.target.shallow_clone(), - ), - }) - } - - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_comp(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.target.unify( - &rchild_arrow.source, - "comp combinator: left target = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: rchild_arrow.target.shallow_clone(), - }) - } - /// Create a unification arrow for a fresh `case` combinator /// /// Either child may be `None`, in which case the combinator is assumed to be @@ -211,10 +90,7 @@ impl Arrow { /// /// If neither child is provided, this function will not raise an error; it /// is the responsibility of the caller to detect this case and error elsewhere. - pub fn for_case( - lchild_arrow: Option<&Arrow>, - rchild_arrow: Option<&Arrow>, - ) -> Result { + fn for_case(lchild_arrow: Option<&Arrow>, rchild_arrow: Option<&Arrow>) -> Result { let a = Type::free(new_name("case_a_")); let b = Type::free(new_name("case_b_")); let c = Type::free(new_name("case_c_")); @@ -247,8 +123,8 @@ impl Arrow { }) } - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { + /// Helper function to combine code for the two `DisconnectConstructible` impls for [`Arrow`]. + fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { let a = Type::free(new_name("disconnect_a_")); let b = Type::free(new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); @@ -272,43 +148,76 @@ impl Arrow { target: prod_b_d, }) } - - /// Same as [`Self::clone`] but named to make it clearer that this is cheap - pub fn shallow_clone(&self) -> Self { - Arrow { - source: self.source.shallow_clone(), - target: self.target.shallow_clone(), - } - } } impl CoreConstructible for Arrow { fn iden() -> Self { - Self::for_iden() + // Throughout this module, when two types are the same, we reuse a + // pointer to them rather than creating distinct types and unifying + // them. This theoretically could lead to more confusing errors for + // the user during type inference, but in practice type inference + // is completely opaque and there's no harm in making it moreso. + let new = Type::free(new_name("iden_src_")); + Arrow { + source: new.shallow_clone(), + target: new, + } } fn unit() -> Self { - Self::for_unit() + Arrow { + source: Type::free(new_name("unit_src_")), + target: Type::unit(), + } } fn injl(child: &Self) -> Self { - Self::for_injl(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + child.target.shallow_clone(), + Type::free(new_name("injl_tgt_")), + ), + } } fn injr(child: &Self) -> Self { - Self::for_injr(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + Type::free(new_name("injr_tgt_")), + child.target.shallow_clone(), + ), + } } fn take(child: &Self) -> Self { - Self::for_take(child) + Arrow { + source: Type::product( + child.source.shallow_clone(), + Type::free(new_name("take_src_")), + ), + target: child.target.shallow_clone(), + } } fn drop_(child: &Self) -> Self { - Self::for_drop(child) + Arrow { + source: Type::product( + Type::free(new_name("drop_src_")), + child.source.shallow_clone(), + ), + target: child.target.shallow_clone(), + } } fn comp(left: &Self, right: &Self) -> Result { - Self::for_comp(left, right) + left.target + .unify(&right.source, "comp combinator: left target = right source")?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: right.target.shallow_clone(), + }) } fn case(left: &Self, right: &Self) -> Result { @@ -324,15 +233,30 @@ impl CoreConstructible for Arrow { } fn pair(left: &Self, right: &Self) -> Result { - Self::for_pair(left, right) + left.source + .unify(&right.source, "pair combinator: left source = right source")?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: Type::product(left.target.shallow_clone(), right.target.shallow_clone()), + }) } fn fail(_: crate::FailEntropy) -> Self { - Self::for_fail() + Arrow { + source: Type::free(new_name("fail_src_")), + target: Type::free(new_name("fail_tgt_")), + } } fn const_word(word: Arc) -> Self { - Self::for_const_word(&word) + let len = word.len(); + assert!(len > 0, "Words must not be the empty bitstring"); + assert!(len.is_power_of_two()); + let depth = word.len().trailing_zeros(); + Arrow { + source: Type::unit(), + target: Type::two_two_n(depth as usize), + } } } @@ -365,12 +289,18 @@ impl DisconnectConstructible> for Arrow { impl JetConstructible for Arrow { fn jet(jet: J) -> Self { - Self::for_jet(jet) + Arrow { + source: jet.source_ty().to_type(), + target: jet.target_ty().to_type(), + } } } impl WitnessConstructible for Arrow { fn witness(_: W) -> Self { - Self::for_witness() + Arrow { + source: Type::free(new_name("witness_src_")), + target: Type::free(new_name("witness_tgt_")), + } } } From dcba2f4e68e42f5ebfba66c29283ecbcd1740a3d Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 22:32:49 +0000 Subject: [PATCH 2/4] named_node: name the fields in the Populator struct Refactor only. --- src/human_encoding/named_node.rs | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index fec35c17..a4d208a6 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -113,11 +113,11 @@ impl NamedCommitNode { witness: &HashMap, Arc>, disconnect: &HashMap, Arc>>, ) -> Arc> { - struct Populator<'a, J: Jet>( - &'a HashMap, Arc>, - &'a HashMap, Arc>>, - PhantomData, - ); + struct Populator<'a, J: Jet> { + witness_map: &'a HashMap, Arc>, + disconnect_map: &'a HashMap, Arc>>, + phantom: PhantomData, + } impl<'a, J: Jet> Converter>, Witness> for Populator<'a, J> { type Error = (); @@ -133,7 +133,7 @@ impl NamedCommitNode { // Which nodes are pruned is not known when this code is executed. // If an unpruned node is unpopulated, then there will be an error // during the finalization. - Ok(self.0.get(name).cloned()) + Ok(self.witness_map.get(name).cloned()) } fn convert_disconnect( @@ -152,7 +152,7 @@ impl NamedCommitNode { // We keep the missing disconnected branches empty. // Like witness nodes (see above), disconnect nodes may be pruned later. // The finalization will detect missing branches and throw an error. - let maybe_commit = self.1.get(hole_name); + let maybe_commit = self.disconnect_map.get(hole_name); // FIXME: Recursive call of to_witness_node // We cannot introduce a stack // because we are implementing methods of the trait Converter @@ -161,7 +161,9 @@ impl NamedCommitNode { // OTOH, if a user writes a program with so many disconnected expressions // that there is a stack overflow, it's his own fault :) // This would fail in a fuzz test. - let witness = maybe_commit.map(|commit| commit.to_witness_node(self.0, self.1)); + let witness = maybe_commit.map(|commit| { + commit.to_witness_node(self.witness_map, self.disconnect_map) + }); Ok(witness) } } @@ -183,8 +185,12 @@ impl NamedCommitNode { } } - self.convert::(&mut Populator(witness, disconnect, PhantomData)) - .unwrap() + self.convert::(&mut Populator { + witness_map: witness, + disconnect_map: disconnect, + phantom: PhantomData, + }) + .unwrap() } /// Encode a Simplicity expression to bits without any witness data From 12a188570a1cdd5adf73cef73512da58aa9cec2f Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 14:55:37 +0000 Subject: [PATCH 3/4] types: refactor precomputed data to directly store finalize types This eliminates an `unwrap`, cleans up a bit of code in named_node.rs, and moves every instance of `Type::from` into types/mod.rs, where it will be easy to modify or remove in a later commit. Also removes From> for Type, a somewhat-weird From impl which was introduced in #218. It is now superceded by Type::complete. In general I would like to remove From impls from Type because they have an inflexible API and because they have somewhat nonobvious behavior. --- src/human_encoding/named_node.rs | 8 ++------ src/human_encoding/parse/ast.rs | 4 +--- src/types/final_data.rs | 9 +-------- src/types/mod.rs | 7 ++++++- src/types/precomputed.rs | 23 ++++++++++++----------- 5 files changed, 22 insertions(+), 29 deletions(-) diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index a4d208a6..f282d6f0 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -413,17 +413,13 @@ impl NamedConstructNode { if self.for_main { // For `main`, only apply type ascriptions *after* inference has completely // determined the type. - let source_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().source)); - let source_ty = types::Type::from(source_bound); + let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source)); for ty in data.node.cached_data().user_source_types.as_ref() { if let Err(e) = source_ty.unify(ty, "binding source type annotation") { self.errors.add(data.node.position(), e); } } - let target_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().target)); - let target_ty = types::Type::from(target_bound); + let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target)); for ty in data.node.cached_data().user_target_types.as_ref() { if let Err(e) = target_ty.unify(ty, "binding target type annotation") { self.errors.add(data.node.position(), e); diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index 241b606b..f58e9d4e 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -633,9 +633,7 @@ fn grammar() -> Grammar> { Error::BadWordLength { bit_length }, )); } - let ty = types::Type::two_two_n(bit_length.trailing_zeros() as usize) - .final_data() - .unwrap(); + let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize); // unwrap ok here since literally every sequence of bits is a valid // value for the given type let value = iter.read_value(&ty).unwrap(); diff --git a/src/types/final_data.rs b/src/types/final_data.rs index 1bfc06e9..8858edd3 100644 --- a/src/types/final_data.rs +++ b/src/types/final_data.rs @@ -12,7 +12,6 @@ //! use crate::dag::{Dag, DagLike, NoSharing}; -use crate::types::{Bound, Type}; use crate::Tmr; use std::sync::Arc; @@ -163,7 +162,7 @@ impl Final { /// /// The type is precomputed and fast to access. pub fn two_two_n(n: usize) -> Arc { - super::precomputed::nth_power_of_2(n).final_data().unwrap() + super::precomputed::nth_power_of_2(n) } /// Create the sum of the given `left` and `right` types. @@ -227,12 +226,6 @@ impl Final { } } -impl From> for Type { - fn from(value: Arc) -> Self { - Type::from(Bound::Complete(value)) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/types/mod.rs b/src/types/mod.rs index 3aad8f91..9461008e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -402,7 +402,7 @@ impl Type { /// /// The type is precomputed and fast to access. pub fn two_two_n(n: usize) -> Self { - precomputed::nth_power_of_2(n) + Self::complete(precomputed::nth_power_of_2(n)) } /// Create the sum of the given `left` and `right` types. @@ -415,6 +415,11 @@ impl Type { Type::from(Bound::product(left, right)) } + /// Create a complete type. + pub fn complete(final_data: Arc) -> Self { + Type::from(Bound::Complete(final_data)) + } + /// Clones the `Type`. /// /// This is the same as just calling `.clone()` but has a different name to diff --git a/src/types/precomputed.rs b/src/types/precomputed.rs index 86d6a683..fb67ae32 100644 --- a/src/types/precomputed.rs +++ b/src/types/precomputed.rs @@ -14,33 +14,34 @@ use crate::Tmr; -use super::Type; +use super::Final; use std::cell::RefCell; use std::convert::TryInto; +use std::sync::Arc; // Directly use the size of the precomputed TMR table to make sure they're in sync. const N_POWERS: usize = Tmr::POWERS_OF_TWO.len(); thread_local! { - static POWERS_OF_TWO: RefCell> = RefCell::new(None); + static POWERS_OF_TWO: RefCell; N_POWERS]>> = RefCell::new(None); } -fn initialize(write: &mut Option<[Type; N_POWERS]>) { - let one = Type::unit(); +fn initialize(write: &mut Option<[Arc; N_POWERS]>) { + let one = Final::unit(); let mut powers = Vec::with_capacity(N_POWERS); // Two^(2^0) = Two = (One + One) - let mut power = Type::sum(one.shallow_clone(), one); - powers.push(power.shallow_clone()); + let mut power = Final::sum(Arc::clone(&one), one); + powers.push(Arc::clone(&power)); // Two^(2^(i + 1)) = (Two^(2^i) * Two^(2^i)) for _ in 1..N_POWERS { - power = Type::product(power.shallow_clone(), power); - powers.push(power.shallow_clone()); + power = Final::product(Arc::clone(&power), power); + powers.push(Arc::clone(&power)); } - let powers: [Type; N_POWERS] = powers.try_into().unwrap(); + let powers: [Arc; N_POWERS] = powers.try_into().unwrap(); *write = Some(powers); } @@ -49,12 +50,12 @@ fn initialize(write: &mut Option<[Type; N_POWERS]>) { /// # Panics /// /// Panics if you request a number `n` greater than or equal to [`Tmr::POWERS_OF_TWO`]. -pub fn nth_power_of_2(n: usize) -> Type { +pub fn nth_power_of_2(n: usize) -> Arc { POWERS_OF_TWO.with(|arr| { if arr.borrow().is_none() { initialize(&mut arr.borrow_mut()); } debug_assert!(arr.borrow().is_some()); - arr.borrow().as_ref().unwrap()[n].shallow_clone() + Arc::clone(&arr.borrow().as_ref().unwrap()[n]) }) } From dc83206e8916df3ae36dd4b466afcba968b4af65 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 16:23:19 +0000 Subject: [PATCH 4/4] jet: refactor type_name The type_name conversion methods use a private TypeConstructible trait which winds up taking more code to use than it saves by reducing reuse. (Though arguably the reuse makes the code more obviously internally consistent.) In future we won't be able to use this trait anyway because the Type constructor will need a type-inference context while the other constructors will not. So delete it now. --- src/jet/type_name.rs | 131 +++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 87 deletions(-) diff --git a/src/jet/type_name.rs b/src/jet/type_name.rs index 3a2e49a7..de6027e3 100644 --- a/src/jet/type_name.rs +++ b/src/jet/type_name.rs @@ -30,94 +30,68 @@ use std::sync::Arc; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct TypeName(pub &'static [u8]); -trait TypeConstructible { - fn two_two_n(n: Option) -> Self; - fn sum(left: Self, right: Self) -> Self; - fn product(left: Self, right: Self) -> Self; -} - -impl TypeConstructible for Type { - fn two_two_n(n: Option) -> Self { - match n { - None => Type::unit(), - Some(m) => Type::two_two_n(m as usize), // cast safety: 32-bit arch or higher - } +impl TypeName { + /// Convert the type name into a type. + pub fn to_type(&self) -> Type { + Type::complete(self.to_final()) } - fn sum(left: Self, right: Self) -> Self { - Type::sum(left, right) - } + /// Convert the type name into a finalized type. + pub fn to_final(&self) -> Arc { + let mut stack = Vec::with_capacity(16); - fn product(left: Self, right: Self) -> Self { - Type::product(left, right) - } -} + for c in self.0.iter().rev() { + match c { + b'1' => stack.push(Final::unit()), + b'2' => stack.push(Final::two_two_n(0)), + b'c' => stack.push(Final::two_two_n(3)), + b's' => stack.push(Final::two_two_n(4)), + b'i' => stack.push(Final::two_two_n(5)), + b'l' => stack.push(Final::two_two_n(6)), + b'h' => stack.push(Final::two_two_n(8)), + b'+' | b'*' => { + let left = stack.pop().expect("Illegal type name syntax!"); + let right = stack.pop().expect("Illegal type name syntax!"); -impl TypeConstructible for Arc { - fn two_two_n(n: Option) -> Self { - match n { - None => Final::unit(), - Some(m) => Final::two_two_n(m as usize), // cast safety: 32-bit arch or higher + match c { + b'+' => stack.push(Final::sum(left, right)), + b'*' => stack.push(Final::product(left, right)), + _ => unreachable!(), + } + } + _ => panic!("Illegal type name syntax!"), + } } - } - - fn sum(left: Self, right: Self) -> Self { - Final::sum(left, right) - } - - fn product(left: Self, right: Self) -> Self { - Final::product(left, right) - } -} -struct BitWidth(usize); - -impl TypeConstructible for BitWidth { - fn two_two_n(n: Option) -> Self { - match n { - None => BitWidth(0), - Some(m) => BitWidth(usize::pow(2, m)), + if stack.len() == 1 { + stack.pop().unwrap() + } else { + panic!("Illegal type name syntax!") } } - fn sum(left: Self, right: Self) -> Self { - BitWidth(1 + cmp::max(left.0, right.0)) - } - - fn product(left: Self, right: Self) -> Self { - BitWidth(left.0 + right.0) - } -} - -impl TypeName { - // b'1' = 49 - // b'2' = 50 - // b'c' = 99 - // b's' = 115 - // b'i' = 105 - // b'l' = 108 - // b'h' = 104 - // b'+' = 43 - // b'*' = 42 - fn construct(&self) -> T { + /// Convert the type name into a type's bitwidth. + /// + /// This is more efficient than creating the type and computing its bit-width + pub fn to_bit_width(&self) -> usize { let mut stack = Vec::with_capacity(16); for c in self.0.iter().rev() { match c { - b'1' => stack.push(T::two_two_n(None)), - b'2' => stack.push(T::two_two_n(Some(0))), - b'c' => stack.push(T::two_two_n(Some(3))), - b's' => stack.push(T::two_two_n(Some(4))), - b'i' => stack.push(T::two_two_n(Some(5))), - b'l' => stack.push(T::two_two_n(Some(6))), - b'h' => stack.push(T::two_two_n(Some(8))), + b'1' => stack.push(0), + b'2' => stack.push(1), + b'c' => stack.push(8), + b's' => stack.push(16), + b'i' => stack.push(32), + b'l' => stack.push(64), + b'h' => stack.push(256), b'+' | b'*' => { let left = stack.pop().expect("Illegal type name syntax!"); let right = stack.pop().expect("Illegal type name syntax!"); match c { - b'+' => stack.push(T::sum(left, right)), - b'*' => stack.push(T::product(left, right)), + b'+' => stack.push(1 + cmp::max(left, right)), + b'*' => stack.push(left + right), _ => unreachable!(), } } @@ -131,21 +105,4 @@ impl TypeName { panic!("Illegal type name syntax!") } } - - /// Convert the type name into a type. - pub fn to_type(&self) -> Type { - self.construct() - } - - /// Convert the type name into a finalized type. - pub fn to_final(&self) -> Arc { - self.construct() - } - - /// Convert the type name into a type's bitwidth. - /// - /// This is more efficient than creating the type and computing its bit-width - pub fn to_bit_width(&self) -> usize { - self.construct::().0 - } }