From 72d2c6285e05063b5f25c80d6b5c6f2f3613059f Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 21:39:45 +0000 Subject: [PATCH 01/16] types: refactor arrow module The arrow module has a ton of `for_*` methods which duplicate the methods in the Constructible traits. Most of these are purely redundant and can be deleted. A couple provide some code reuse, and are demoted to be non-`pub`. Smoe variable names changed in the redundant code, so this will likely not appear as a code-move in the diff, even though it is one. You can just skim this commit. --- src/types/arrow.rs | 232 ++++++++++++++++----------------------------- 1 file changed, 81 insertions(+), 151 deletions(-) diff --git a/src/types/arrow.rs b/src/types/arrow.rs index c18dc9e0..08ba86f7 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -74,135 +74,14 @@ impl Arrow { }) } - /// Create a unification arrow for a fresh `unit` combinator - pub fn for_unit() -> Self { - Arrow { - source: Type::free(new_name("unit_src_")), - target: Type::unit(), - } - } - - /// Create a unification arrow for a fresh `iden` combinator - pub fn for_iden() -> Self { - // Throughout this module, when two types are the same, we reuse a - // pointer to them rather than creating distinct types and unifying - // them. This theoretically could lead to more confusing errors for - // the user during type inference, but in practice type inference - // is completely opaque and there's no harm in making it moreso. 
- let new = Type::free(new_name("iden_src_")); - Arrow { - source: new.shallow_clone(), - target: new, - } - } - - /// Create a unification arrow for a fresh `witness` combinator - pub fn for_witness() -> Self { - Arrow { - source: Type::free(new_name("witness_src_")), - target: Type::free(new_name("witness_tgt_")), - } - } - - /// Create a unification arrow for a fresh `fail` combinator - pub fn for_fail() -> Self { - Arrow { - source: Type::free(new_name("fail_src_")), - target: Type::free(new_name("fail_tgt_")), - } - } - - /// Create a unification arrow for a fresh jet combinator - pub fn for_jet(jet: J) -> Self { - Arrow { - source: jet.source_ty().to_type(), - target: jet.target_ty().to_type(), - } - } - - /// Create a unification arrow for a fresh const-word combinator - pub fn for_const_word(word: &Value) -> Self { - let len = word.len(); - assert!(len > 0, "Words must not be the empty bitstring"); - assert!(len.is_power_of_two()); - let depth = word.len().trailing_zeros(); - Arrow { - source: Type::unit(), - target: Type::two_two_n(depth as usize), - } - } - - /// Create a unification arrow for a fresh `injl` combinator - pub fn for_injl(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - child_arrow.target.shallow_clone(), - Type::free(new_name("injl_tgt_")), - ), - } - } - - /// Create a unification arrow for a fresh `injr` combinator - pub fn for_injr(child_arrow: &Arrow) -> Self { - Arrow { - source: child_arrow.source.shallow_clone(), - target: Type::sum( - Type::free(new_name("injr_tgt_")), - child_arrow.target.shallow_clone(), - ), - } - } - - /// Create a unification arrow for a fresh `take` combinator - pub fn for_take(child_arrow: &Arrow) -> Self { - Arrow { - source: Type::product( - child_arrow.source.shallow_clone(), - Type::free(new_name("take_src_")), - ), - target: child_arrow.target.shallow_clone(), - } - } - - /// Create a unification arrow for a fresh `drop` combinator - pub fn 
for_drop(child_arrow: &Arrow) -> Self { + /// Same as [`Self::clone`] but named to make it clearer that this is cheap + pub fn shallow_clone(&self) -> Self { Arrow { - source: Type::product( - Type::free(new_name("drop_src_")), - child_arrow.source.shallow_clone(), - ), - target: child_arrow.target.shallow_clone(), + source: self.source.shallow_clone(), + target: self.target.shallow_clone(), } } - /// Create a unification arrow for a fresh `pair` combinator - pub fn for_pair(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.source.unify( - &rchild_arrow.source, - "pair combinator: left source = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: Type::product( - lchild_arrow.target.shallow_clone(), - rchild_arrow.target.shallow_clone(), - ), - }) - } - - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_comp(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { - lchild_arrow.target.unify( - &rchild_arrow.source, - "comp combinator: left target = right source", - )?; - Ok(Arrow { - source: lchild_arrow.source.shallow_clone(), - target: rchild_arrow.target.shallow_clone(), - }) - } - /// Create a unification arrow for a fresh `case` combinator /// /// Either child may be `None`, in which case the combinator is assumed to be @@ -211,10 +90,7 @@ impl Arrow { /// /// If neither child is provided, this function will not raise an error; it /// is the responsibility of the caller to detect this case and error elsewhere. 
- pub fn for_case( - lchild_arrow: Option<&Arrow>, - rchild_arrow: Option<&Arrow>, - ) -> Result { + fn for_case(lchild_arrow: Option<&Arrow>, rchild_arrow: Option<&Arrow>) -> Result { let a = Type::free(new_name("case_a_")); let b = Type::free(new_name("case_b_")); let c = Type::free(new_name("case_c_")); @@ -247,8 +123,8 @@ impl Arrow { }) } - /// Create a unification arrow for a fresh `comp` combinator - pub fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { + /// Helper function to combine code for the two `DisconnectConstructible` impls for [`Arrow`]. + fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { let a = Type::free(new_name("disconnect_a_")); let b = Type::free(new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); @@ -272,43 +148,76 @@ impl Arrow { target: prod_b_d, }) } - - /// Same as [`Self::clone`] but named to make it clearer that this is cheap - pub fn shallow_clone(&self) -> Self { - Arrow { - source: self.source.shallow_clone(), - target: self.target.shallow_clone(), - } - } } impl CoreConstructible for Arrow { fn iden() -> Self { - Self::for_iden() + // Throughout this module, when two types are the same, we reuse a + // pointer to them rather than creating distinct types and unifying + // them. This theoretically could lead to more confusing errors for + // the user during type inference, but in practice type inference + // is completely opaque and there's no harm in making it moreso. 
+ let new = Type::free(new_name("iden_src_")); + Arrow { + source: new.shallow_clone(), + target: new, + } } fn unit() -> Self { - Self::for_unit() + Arrow { + source: Type::free(new_name("unit_src_")), + target: Type::unit(), + } } fn injl(child: &Self) -> Self { - Self::for_injl(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + child.target.shallow_clone(), + Type::free(new_name("injl_tgt_")), + ), + } } fn injr(child: &Self) -> Self { - Self::for_injr(child) + Arrow { + source: child.source.shallow_clone(), + target: Type::sum( + Type::free(new_name("injr_tgt_")), + child.target.shallow_clone(), + ), + } } fn take(child: &Self) -> Self { - Self::for_take(child) + Arrow { + source: Type::product( + child.source.shallow_clone(), + Type::free(new_name("take_src_")), + ), + target: child.target.shallow_clone(), + } } fn drop_(child: &Self) -> Self { - Self::for_drop(child) + Arrow { + source: Type::product( + Type::free(new_name("drop_src_")), + child.source.shallow_clone(), + ), + target: child.target.shallow_clone(), + } } fn comp(left: &Self, right: &Self) -> Result { - Self::for_comp(left, right) + left.target + .unify(&right.source, "comp combinator: left target = right source")?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: right.target.shallow_clone(), + }) } fn case(left: &Self, right: &Self) -> Result { @@ -324,15 +233,30 @@ impl CoreConstructible for Arrow { } fn pair(left: &Self, right: &Self) -> Result { - Self::for_pair(left, right) + left.source + .unify(&right.source, "pair combinator: left source = right source")?; + Ok(Arrow { + source: left.source.shallow_clone(), + target: Type::product(left.target.shallow_clone(), right.target.shallow_clone()), + }) } fn fail(_: crate::FailEntropy) -> Self { - Self::for_fail() + Arrow { + source: Type::free(new_name("fail_src_")), + target: Type::free(new_name("fail_tgt_")), + } } fn const_word(word: Arc) -> Self { - Self::for_const_word(&word) + let len = 
word.len(); + assert!(len > 0, "Words must not be the empty bitstring"); + assert!(len.is_power_of_two()); + let depth = word.len().trailing_zeros(); + Arrow { + source: Type::unit(), + target: Type::two_two_n(depth as usize), + } } } @@ -365,12 +289,18 @@ impl DisconnectConstructible> for Arrow { impl JetConstructible for Arrow { fn jet(jet: J) -> Self { - Self::for_jet(jet) + Arrow { + source: jet.source_ty().to_type(), + target: jet.target_ty().to_type(), + } } } impl WitnessConstructible for Arrow { fn witness(_: W) -> Self { - Self::for_witness() + Arrow { + source: Type::free(new_name("witness_src_")), + target: Type::free(new_name("witness_tgt_")), + } } } From dcba2f4e68e42f5ebfba66c29283ecbcd1740a3d Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 22:32:49 +0000 Subject: [PATCH 02/16] named_node: name the fields in the Populator struct Refactor only. --- src/human_encoding/named_node.rs | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index fec35c17..a4d208a6 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -113,11 +113,11 @@ impl NamedCommitNode { witness: &HashMap, Arc>, disconnect: &HashMap, Arc>>, ) -> Arc> { - struct Populator<'a, J: Jet>( - &'a HashMap, Arc>, - &'a HashMap, Arc>>, - PhantomData, - ); + struct Populator<'a, J: Jet> { + witness_map: &'a HashMap, Arc>, + disconnect_map: &'a HashMap, Arc>>, + phantom: PhantomData, + } impl<'a, J: Jet> Converter>, Witness> for Populator<'a, J> { type Error = (); @@ -133,7 +133,7 @@ impl NamedCommitNode { // Which nodes are pruned is not known when this code is executed. // If an unpruned node is unpopulated, then there will be an error // during the finalization. 
- Ok(self.0.get(name).cloned()) + Ok(self.witness_map.get(name).cloned()) } fn convert_disconnect( @@ -152,7 +152,7 @@ impl NamedCommitNode { // We keep the missing disconnected branches empty. // Like witness nodes (see above), disconnect nodes may be pruned later. // The finalization will detect missing branches and throw an error. - let maybe_commit = self.1.get(hole_name); + let maybe_commit = self.disconnect_map.get(hole_name); // FIXME: Recursive call of to_witness_node // We cannot introduce a stack // because we are implementing methods of the trait Converter @@ -161,7 +161,9 @@ impl NamedCommitNode { // OTOH, if a user writes a program with so many disconnected expressions // that there is a stack overflow, it's his own fault :) // This would fail in a fuzz test. - let witness = maybe_commit.map(|commit| commit.to_witness_node(self.0, self.1)); + let witness = maybe_commit.map(|commit| { + commit.to_witness_node(self.witness_map, self.disconnect_map) + }); Ok(witness) } } @@ -183,8 +185,12 @@ impl NamedCommitNode { } } - self.convert::(&mut Populator(witness, disconnect, PhantomData)) - .unwrap() + self.convert::(&mut Populator { + witness_map: witness, + disconnect_map: disconnect, + phantom: PhantomData, + }) + .unwrap() } /// Encode a Simplicity expression to bits without any witness data From 12a188570a1cdd5adf73cef73512da58aa9cec2f Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 14:55:37 +0000 Subject: [PATCH 03/16] types: refactor precomputed data to directly store finalize types This eliminates an `unwrap`, cleans up a bit of code in named_node.rs, and moves every instance of `Type::from` into types/mod.rs, where it will be easy to modify or remove in a later commit. Also removes From> for Type, a somewhat-weird From impl which was introduced in #218. It is now superceded by Type::complete. 
In general I would like to remove From impls from Type because they have an inflexible API and because they have somewhat nonobvious behavior. --- src/human_encoding/named_node.rs | 8 ++------ src/human_encoding/parse/ast.rs | 4 +--- src/types/final_data.rs | 9 +-------- src/types/mod.rs | 7 ++++++- src/types/precomputed.rs | 23 ++++++++++++----------- 5 files changed, 22 insertions(+), 29 deletions(-) diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index a4d208a6..f282d6f0 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -413,17 +413,13 @@ impl NamedConstructNode { if self.for_main { // For `main`, only apply type ascriptions *after* inference has completely // determined the type. - let source_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().source)); - let source_ty = types::Type::from(source_bound); + let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source)); for ty in data.node.cached_data().user_source_types.as_ref() { if let Err(e) = source_ty.unify(ty, "binding source type annotation") { self.errors.add(data.node.position(), e); } } - let target_bound = - types::Bound::Complete(Arc::clone(&commit_data.arrow().target)); - let target_ty = types::Type::from(target_bound); + let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target)); for ty in data.node.cached_data().user_target_types.as_ref() { if let Err(e) = target_ty.unify(ty, "binding target type annotation") { self.errors.add(data.node.position(), e); diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index 241b606b..f58e9d4e 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -633,9 +633,7 @@ fn grammar() -> Grammar> { Error::BadWordLength { bit_length }, )); } - let ty = types::Type::two_two_n(bit_length.trailing_zeros() as usize) - .final_data() - .unwrap(); + let ty = 
types::Final::two_two_n(bit_length.trailing_zeros() as usize); // unwrap ok here since literally every sequence of bits is a valid // value for the given type let value = iter.read_value(&ty).unwrap(); diff --git a/src/types/final_data.rs b/src/types/final_data.rs index 1bfc06e9..8858edd3 100644 --- a/src/types/final_data.rs +++ b/src/types/final_data.rs @@ -12,7 +12,6 @@ //! use crate::dag::{Dag, DagLike, NoSharing}; -use crate::types::{Bound, Type}; use crate::Tmr; use std::sync::Arc; @@ -163,7 +162,7 @@ impl Final { /// /// The type is precomputed and fast to access. pub fn two_two_n(n: usize) -> Arc { - super::precomputed::nth_power_of_2(n).final_data().unwrap() + super::precomputed::nth_power_of_2(n) } /// Create the sum of the given `left` and `right` types. @@ -227,12 +226,6 @@ impl Final { } } -impl From> for Type { - fn from(value: Arc) -> Self { - Type::from(Bound::Complete(value)) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/types/mod.rs b/src/types/mod.rs index 3aad8f91..9461008e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -402,7 +402,7 @@ impl Type { /// /// The type is precomputed and fast to access. pub fn two_two_n(n: usize) -> Self { - precomputed::nth_power_of_2(n) + Self::complete(precomputed::nth_power_of_2(n)) } /// Create the sum of the given `left` and `right` types. @@ -415,6 +415,11 @@ impl Type { Type::from(Bound::product(left, right)) } + /// Create a complete type. + pub fn complete(final_data: Arc) -> Self { + Type::from(Bound::Complete(final_data)) + } + /// Clones the `Type`. 
/// /// This is the same as just calling `.clone()` but has a different name to diff --git a/src/types/precomputed.rs b/src/types/precomputed.rs index 86d6a683..fb67ae32 100644 --- a/src/types/precomputed.rs +++ b/src/types/precomputed.rs @@ -14,33 +14,34 @@ use crate::Tmr; -use super::Type; +use super::Final; use std::cell::RefCell; use std::convert::TryInto; +use std::sync::Arc; // Directly use the size of the precomputed TMR table to make sure they're in sync. const N_POWERS: usize = Tmr::POWERS_OF_TWO.len(); thread_local! { - static POWERS_OF_TWO: RefCell> = RefCell::new(None); + static POWERS_OF_TWO: RefCell; N_POWERS]>> = RefCell::new(None); } -fn initialize(write: &mut Option<[Type; N_POWERS]>) { - let one = Type::unit(); +fn initialize(write: &mut Option<[Arc; N_POWERS]>) { + let one = Final::unit(); let mut powers = Vec::with_capacity(N_POWERS); // Two^(2^0) = Two = (One + One) - let mut power = Type::sum(one.shallow_clone(), one); - powers.push(power.shallow_clone()); + let mut power = Final::sum(Arc::clone(&one), one); + powers.push(Arc::clone(&power)); // Two^(2^(i + 1)) = (Two^(2^i) * Two^(2^i)) for _ in 1..N_POWERS { - power = Type::product(power.shallow_clone(), power); - powers.push(power.shallow_clone()); + power = Final::product(Arc::clone(&power), power); + powers.push(Arc::clone(&power)); } - let powers: [Type; N_POWERS] = powers.try_into().unwrap(); + let powers: [Arc; N_POWERS] = powers.try_into().unwrap(); *write = Some(powers); } @@ -49,12 +50,12 @@ fn initialize(write: &mut Option<[Type; N_POWERS]>) { /// # Panics /// /// Panics if you request a number `n` greater than or equal to [`Tmr::POWERS_OF_TWO`]. 
-pub fn nth_power_of_2(n: usize) -> Type { +pub fn nth_power_of_2(n: usize) -> Arc { POWERS_OF_TWO.with(|arr| { if arr.borrow().is_none() { initialize(&mut arr.borrow_mut()); } debug_assert!(arr.borrow().is_some()); - arr.borrow().as_ref().unwrap()[n].shallow_clone() + Arc::clone(&arr.borrow().as_ref().unwrap()[n]) }) } From dc83206e8916df3ae36dd4b466afcba968b4af65 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 16:23:19 +0000 Subject: [PATCH 04/16] jet: refactor type_name The type_name conversion methods use a private TypeConstructible trait which winds up taking more code to use than it saves by reducing reuse. (Though arguably the reuse makes the code more obviously internally consistent.) In future we won't be able to use this trait anyway because the Type constructor will need a type-inference context while the other constructors will not. So delete it now. --- src/jet/type_name.rs | 131 +++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 87 deletions(-) diff --git a/src/jet/type_name.rs b/src/jet/type_name.rs index 3a2e49a7..de6027e3 100644 --- a/src/jet/type_name.rs +++ b/src/jet/type_name.rs @@ -30,94 +30,68 @@ use std::sync::Arc; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct TypeName(pub &'static [u8]); -trait TypeConstructible { - fn two_two_n(n: Option) -> Self; - fn sum(left: Self, right: Self) -> Self; - fn product(left: Self, right: Self) -> Self; -} - -impl TypeConstructible for Type { - fn two_two_n(n: Option) -> Self { - match n { - None => Type::unit(), - Some(m) => Type::two_two_n(m as usize), // cast safety: 32-bit arch or higher - } +impl TypeName { + /// Convert the type name into a type. + pub fn to_type(&self) -> Type { + Type::complete(self.to_final()) } - fn sum(left: Self, right: Self) -> Self { - Type::sum(left, right) - } + /// Convert the type name into a finalized type. 
+ pub fn to_final(&self) -> Arc { + let mut stack = Vec::with_capacity(16); - fn product(left: Self, right: Self) -> Self { - Type::product(left, right) - } -} + for c in self.0.iter().rev() { + match c { + b'1' => stack.push(Final::unit()), + b'2' => stack.push(Final::two_two_n(0)), + b'c' => stack.push(Final::two_two_n(3)), + b's' => stack.push(Final::two_two_n(4)), + b'i' => stack.push(Final::two_two_n(5)), + b'l' => stack.push(Final::two_two_n(6)), + b'h' => stack.push(Final::two_two_n(8)), + b'+' | b'*' => { + let left = stack.pop().expect("Illegal type name syntax!"); + let right = stack.pop().expect("Illegal type name syntax!"); -impl TypeConstructible for Arc { - fn two_two_n(n: Option) -> Self { - match n { - None => Final::unit(), - Some(m) => Final::two_two_n(m as usize), // cast safety: 32-bit arch or higher + match c { + b'+' => stack.push(Final::sum(left, right)), + b'*' => stack.push(Final::product(left, right)), + _ => unreachable!(), + } + } + _ => panic!("Illegal type name syntax!"), + } } - } - - fn sum(left: Self, right: Self) -> Self { - Final::sum(left, right) - } - - fn product(left: Self, right: Self) -> Self { - Final::product(left, right) - } -} -struct BitWidth(usize); - -impl TypeConstructible for BitWidth { - fn two_two_n(n: Option) -> Self { - match n { - None => BitWidth(0), - Some(m) => BitWidth(usize::pow(2, m)), + if stack.len() == 1 { + stack.pop().unwrap() + } else { + panic!("Illegal type name syntax!") } } - fn sum(left: Self, right: Self) -> Self { - BitWidth(1 + cmp::max(left.0, right.0)) - } - - fn product(left: Self, right: Self) -> Self { - BitWidth(left.0 + right.0) - } -} - -impl TypeName { - // b'1' = 49 - // b'2' = 50 - // b'c' = 99 - // b's' = 115 - // b'i' = 105 - // b'l' = 108 - // b'h' = 104 - // b'+' = 43 - // b'*' = 42 - fn construct(&self) -> T { + /// Convert the type name into a type's bitwidth. 
+ /// + /// This is more efficient than creating the type and computing its bit-width + pub fn to_bit_width(&self) -> usize { let mut stack = Vec::with_capacity(16); for c in self.0.iter().rev() { match c { - b'1' => stack.push(T::two_two_n(None)), - b'2' => stack.push(T::two_two_n(Some(0))), - b'c' => stack.push(T::two_two_n(Some(3))), - b's' => stack.push(T::two_two_n(Some(4))), - b'i' => stack.push(T::two_two_n(Some(5))), - b'l' => stack.push(T::two_two_n(Some(6))), - b'h' => stack.push(T::two_two_n(Some(8))), + b'1' => stack.push(0), + b'2' => stack.push(1), + b'c' => stack.push(8), + b's' => stack.push(16), + b'i' => stack.push(32), + b'l' => stack.push(64), + b'h' => stack.push(256), b'+' | b'*' => { let left = stack.pop().expect("Illegal type name syntax!"); let right = stack.pop().expect("Illegal type name syntax!"); match c { - b'+' => stack.push(T::sum(left, right)), - b'*' => stack.push(T::product(left, right)), + b'+' => stack.push(1 + cmp::max(left, right)), + b'*' => stack.push(left + right), _ => unreachable!(), } } @@ -131,21 +105,4 @@ impl TypeName { panic!("Illegal type name syntax!") } } - - /// Convert the type name into a type. - pub fn to_type(&self) -> Type { - self.construct() - } - - /// Convert the type name into a finalized type. - pub fn to_final(&self) -> Arc { - self.construct() - } - - /// Convert the type name into a type's bitwidth. 
- /// - /// This is more efficient than creating the type and computing its bit-width - pub fn to_bit_width(&self) -> usize { - self.construct::().0 - } } From 17c1828d515e35ed0b09596a0465ac102bc695c5 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 22:08:11 +0000 Subject: [PATCH 05/16] cmr: pull Constructible impl on Cmr into an impl on an auxiliary type This looks like a purely noise-increasing change, but when we introduce type inference contexts, we will need the auxiliary type to hold an inference context (which will be unused except for sanity-checking that users are being consistent with their inference contexts). --- src/merkle/cmr.rs | 73 ++++++++++++++++++++++++++++++----------- src/policy/ast.rs | 3 +- src/policy/serialize.rs | 7 ++-- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/src/merkle/cmr.rs b/src/merkle/cmr.rs index 8c8e421c..00ef9ff7 100644 --- a/src/merkle/cmr.rs +++ b/src/merkle/cmr.rs @@ -253,75 +253,108 @@ impl Cmr { ]; } -impl CoreConstructible for Cmr { +/// Wrapper around a CMR which allows it to be constructed with the +/// `*Constructible*` traits, allowing CMRs to be computed using the +/// same generic construction code that nodes are. 
+pub struct ConstructibleCmr { + pub cmr: Cmr, +} + +impl CoreConstructible for ConstructibleCmr { fn iden() -> Self { - Cmr::iden() + ConstructibleCmr { cmr: Cmr::iden() } } fn unit() -> Self { - Cmr::unit() + ConstructibleCmr { cmr: Cmr::unit() } } fn injl(child: &Self) -> Self { - Cmr::injl(*child) + ConstructibleCmr { + cmr: Cmr::injl(child.cmr), + } } fn injr(child: &Self) -> Self { - Cmr::injl(*child) + ConstructibleCmr { + cmr: Cmr::injl(child.cmr), + } } fn take(child: &Self) -> Self { - Cmr::take(*child) + ConstructibleCmr { + cmr: Cmr::take(child.cmr), + } } fn drop_(child: &Self) -> Self { - Cmr::drop(*child) + ConstructibleCmr { + cmr: Cmr::drop(child.cmr), + } } fn comp(left: &Self, right: &Self) -> Result { - Ok(Cmr::comp(*left, *right)) + Ok(ConstructibleCmr { + cmr: Cmr::comp(left.cmr, right.cmr), + }) } fn case(left: &Self, right: &Self) -> Result { - Ok(Cmr::case(*left, *right)) + Ok(ConstructibleCmr { + cmr: Cmr::case(left.cmr, right.cmr), + }) } fn assertl(left: &Self, right: Cmr) -> Result { - Ok(Cmr::case(*left, right)) + Ok(ConstructibleCmr { + cmr: Cmr::case(left.cmr, right), + }) } fn assertr(left: Cmr, right: &Self) -> Result { - Ok(Cmr::case(left, *right)) + Ok(ConstructibleCmr { + cmr: Cmr::case(left, right.cmr), + }) } fn pair(left: &Self, right: &Self) -> Result { - Ok(Cmr::pair(*left, *right)) + Ok(ConstructibleCmr { + cmr: Cmr::pair(left.cmr, right.cmr), + }) } fn fail(entropy: FailEntropy) -> Self { - Cmr::fail(entropy) + ConstructibleCmr { + cmr: Cmr::fail(entropy), + } } fn const_word(word: Arc) -> Self { - Cmr::const_word(&word) + ConstructibleCmr { + cmr: Cmr::const_word(&word), + } } } -impl DisconnectConstructible for Cmr { +impl DisconnectConstructible for ConstructibleCmr { fn disconnect(left: &Self, _right: &X) -> Result { - Ok(Cmr::disconnect(*left)) + Ok(ConstructibleCmr { + cmr: Cmr::disconnect(left.cmr), + }) } } -impl WitnessConstructible for Cmr { +impl WitnessConstructible for ConstructibleCmr { fn witness(_witness: 
W) -> Self { - Cmr::witness() + ConstructibleCmr { + cmr: Cmr::witness(), + } } } -impl JetConstructible for Cmr { +impl JetConstructible for ConstructibleCmr { fn jet(jet: J) -> Self { - jet.cmr() + ConstructibleCmr { cmr: jet.cmr() } } } diff --git a/src/policy/ast.rs b/src/policy/ast.rs index 5d55ec29..8c87c476 100644 --- a/src/policy/ast.rs +++ b/src/policy/ast.rs @@ -112,8 +112,9 @@ impl Policy { /// Return the CMR of the policy. pub fn cmr(&self) -> Cmr { - self.serialize_no_witness() + self.serialize_no_witness::() .expect("CMR is defined for asm fragment") + .cmr } } diff --git a/src/policy/serialize.rs b/src/policy/serialize.rs index 2322fbf0..ad6724fc 100644 --- a/src/policy/serialize.rs +++ b/src/policy/serialize.rs @@ -3,6 +3,7 @@ //! Serialization of Policy as Simplicity use crate::jet::{Elements, Jet}; +use crate::merkle::cmr::ConstructibleCmr; use crate::node::{CoreConstructible, JetConstructible, WitnessConstructible}; use crate::{Cmr, ConstructNode, ToXOnlyPubkey}; use crate::{FailEntropy, Value}; @@ -14,13 +15,13 @@ use std::sync::Arc; pub trait AssemblyConstructible: Sized { /// Construct the assembly fragment with the given CMR. /// - /// The construction fails if the CMR alone is not enough information to construct the type. + /// The construction fails if the CMR alone is not enough information to construct the object. fn assembly(cmr: Cmr) -> Option; } -impl AssemblyConstructible for Cmr { +impl AssemblyConstructible for ConstructibleCmr { fn assembly(cmr: Cmr) -> Option { - Some(cmr) + Some(ConstructibleCmr { cmr }) } } From 1e4d415b9acaaa8adc00d29a91c09f03da536d8b Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Fri, 28 Jun 2024 22:11:48 +0000 Subject: [PATCH 06/16] types: introduce inference context object, thread it through the API This is a super noisy commit, but all it does is introduce the src/types/context.rs module with a dummy Context type, and make all the node construction APIs take this and make copies of handles of it. 
The next commit will actually _use_ the context in type inference. This one can therefore be mostly skimmed. The maybe-interesting parts are the conversion methods, which can be found by searching the diff for calls to `from_inner`. --- src/bit_encoding/bitwriter.rs | 3 +- src/bit_encoding/decode.rs | 14 +++-- src/human_encoding/named_node.rs | 22 +++++--- src/human_encoding/parse/mod.rs | 4 +- src/jet/elements/tests.rs | 3 +- src/jet/mod.rs | 13 +++-- src/merkle/amr.rs | 4 +- src/merkle/cmr.rs | 60 ++++++++++++++++---- src/node/commit.rs | 15 ++++- src/node/construct.rs | 54 +++++++++++------- src/node/mod.rs | 92 ++++++++++++++++++------------- src/node/redeem.rs | 15 +++-- src/node/witness.rs | 43 +++++++++------ src/policy/ast.rs | 38 ++++++++----- src/policy/satisfy.rs | 35 +++++++----- src/policy/serialize.rs | 95 +++++++++++++++++++------------- src/types/arrow.rs | 68 +++++++++++++++++++---- src/types/context.rs | 77 ++++++++++++++++++++++++++ src/types/mod.rs | 15 ++++- 19 files changed, 471 insertions(+), 199 deletions(-) create mode 100644 src/types/context.rs diff --git a/src/bit_encoding/bitwriter.rs b/src/bit_encoding/bitwriter.rs index faae4f82..0a6a3be8 100644 --- a/src/bit_encoding/bitwriter.rs +++ b/src/bit_encoding/bitwriter.rs @@ -117,12 +117,13 @@ mod tests { use super::*; use crate::jet::Core; use crate::node::CoreConstructible; + use crate::types; use crate::ConstructNode; use std::sync::Arc; #[test] fn vec() { - let program = Arc::>::unit(); + let program = Arc::>::unit(&types::Context::new()); let _ = write_to_vec(|w| program.encode(w)); } diff --git a/src/bit_encoding/decode.rs b/src/bit_encoding/decode.rs index 8cd9cf96..d4999d0b 100644 --- a/src/bit_encoding/decode.rs +++ b/src/bit_encoding/decode.rs @@ -12,6 +12,7 @@ use crate::node::{ ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness, WitnessConstructible, }; +use crate::types; use crate::{BitIter, FailEntropy, Value}; use std::collections::HashSet; 
use std::sync::Arc; @@ -178,6 +179,7 @@ pub fn decode_expression, J: Jet>( return Err(Error::TooManyNodes(len)); } + let inference_context = types::Context::new(); let mut nodes = Vec::with_capacity(len); for _ in 0..len { let new_node = decode_node(bits, nodes.len())?; @@ -195,8 +197,8 @@ pub fn decode_expression, J: Jet>( } let new = match nodes[data.node.0] { - DecodeNode::Unit => Node(ArcNode::unit()), - DecodeNode::Iden => Node(ArcNode::iden()), + DecodeNode::Unit => Node(ArcNode::unit(&inference_context)), + DecodeNode::Iden => Node(ArcNode::iden(&inference_context)), DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)), DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)), DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)), @@ -222,16 +224,16 @@ pub fn decode_expression, J: Jet>( converted[i].get()?, &Some(Arc::clone(converted[j].get()?)), )?), - DecodeNode::Witness => Node(ArcNode::witness(NoWitness)), - DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)), + DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)), + DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)), DecodeNode::Hidden(cmr) => { if !hidden_set.insert(cmr) { return Err(Error::SharingNotMaximal); } Hidden(cmr) } - DecodeNode::Jet(j) => Node(ArcNode::jet(j)), - DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))), + DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)), + DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))), }; converted.push(new); } diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index f282d6f0..4e9ee4dd 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -116,6 +116,7 @@ impl NamedCommitNode { struct Populator<'a, J: Jet> { witness_map: &'a HashMap, Arc>, disconnect_map: &'a HashMap, Arc>>, + inference_context: types::Context, phantom: PhantomData, } @@ -153,17 
+154,16 @@ impl NamedCommitNode { // Like witness nodes (see above), disconnect nodes may be pruned later. // The finalization will detect missing branches and throw an error. let maybe_commit = self.disconnect_map.get(hole_name); - // FIXME: Recursive call of to_witness_node - // We cannot introduce a stack - // because we are implementing methods of the trait Converter - // which are used Marker::convert(). + // FIXME: recursive call to convert + // We cannot introduce a stack because we are implementing the Converter + // trait and do not have access to the actual algorithm used for conversion + // in order to save its state. // // OTOH, if a user writes a program with so many disconnected expressions // that there is a stack overflow, it's his own fault :) - // This would fail in a fuzz test. - let witness = maybe_commit.map(|commit| { - commit.to_witness_node(self.witness_map, self.disconnect_map) - }); + // This may fail in a fuzz test. + let witness = maybe_commit + .map(|commit| commit.convert::(self).unwrap()); Ok(witness) } } @@ -181,13 +181,15 @@ impl NamedCommitNode { let inner = inner .map(|node| node.cached_data()) .map_witness(|maybe_value| maybe_value.clone()); - Ok(WitnessData::from_inner(inner).expect("types are already finalized")) + Ok(WitnessData::from_inner(&self.inference_context, inner) + .expect("types are already finalized")) } } self.convert::(&mut Populator { witness_map: witness, disconnect_map: disconnect, + inference_context: types::Context::new(), phantom: PhantomData, }) .unwrap() @@ -245,6 +247,7 @@ pub struct NamedConstructData { impl NamedConstructNode { /// Construct a named construct node from parts. 
pub fn new( + inference_context: &types::Context, name: Arc, position: Position, user_source_types: Arc<[types::Type]>, @@ -252,6 +255,7 @@ impl NamedConstructNode { inner: node::Inner, J, Arc, WitnessOrHole>, ) -> Result { let construct_data = ConstructData::from_inner( + inference_context, inner .as_ref() .map(|data| &data.cached_data().internal) diff --git a/src/human_encoding/parse/mod.rs b/src/human_encoding/parse/mod.rs index 74940cf5..61919d9f 100644 --- a/src/human_encoding/parse/mod.rs +++ b/src/human_encoding/parse/mod.rs @@ -7,7 +7,7 @@ mod ast; use crate::dag::{Dag, DagLike, InternalSharing}; use crate::jet::Jet; use crate::node; -use crate::types::Type; +use crate::types::{self, Type}; use std::collections::HashMap; use std::mem; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -419,6 +419,7 @@ pub fn parse( drop(unresolved_map); // ** Step 3: convert each DAG of names/expressions into a DAG of NamedNodes. + let inference_context = types::Context::new(); let mut roots = HashMap::, Arc>>::new(); for (name, expr) in &resolved_map { if expr.in_degree.load(Ordering::SeqCst) > 0 { @@ -485,6 +486,7 @@ pub fn parse( .unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str())); let node = NamedConstructNode::new( + &inference_context, Arc::clone(&name), data.node.position, Arc::clone(&data.node.user_source_types), diff --git a/src/jet/elements/tests.rs b/src/jet/elements/tests.rs index b852edef..d1bbad9f 100644 --- a/src/jet/elements/tests.rs +++ b/src/jet/elements/tests.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crate::jet::elements::{ElementsEnv, ElementsUtxo}; use crate::jet::Elements; use crate::node::{ConstructNode, JetConstructible}; +use crate::types; use crate::{BitMachine, Cmr, Value}; use elements::secp256k1_zkp::Tweak; use elements::taproot::ControlBlock; @@ -99,7 +100,7 @@ fn test_ffi_env() { BlockHash::all_zeros(), ); - let prog = Arc::>::jet(Elements::LockTime); + let prog = Arc::>::jet(&types::Context::new(), Elements::LockTime); 
assert_eq!( BitMachine::test_exec(prog, &env).expect("executing"), Value::u32(100), diff --git a/src/jet/mod.rs b/src/jet/mod.rs index 556492e3..8d78c202 100644 --- a/src/jet/mod.rs +++ b/src/jet/mod.rs @@ -93,18 +93,20 @@ pub trait Jet: mod tests { use crate::jet::Core; use crate::node::{ConstructNode, CoreConstructible, JetConstructible}; + use crate::types; use crate::{BitMachine, Value}; use std::sync::Arc; #[test] fn test_ffi_jet() { + let ctx = types::Context::new(); let two_words = Arc::>::comp( &Arc::>::pair( - &Arc::>::const_word(Value::u32(2)), - &Arc::>::const_word(Value::u32(16)), + &Arc::>::const_word(&ctx, Value::u32(2)), + &Arc::>::const_word(&ctx, Value::u32(16)), ) .unwrap(), - &Arc::>::jet(Core::Add32), + &Arc::>::jet(&ctx, Core::Add32), ) .unwrap(); assert_eq!( @@ -118,9 +120,10 @@ mod tests { #[test] fn test_simple() { + let ctx = types::Context::new(); let two_words = Arc::>::pair( - &Arc::>::const_word(Value::u32(2)), - &Arc::>::const_word(Value::u16(16)), + &Arc::>::const_word(&ctx, Value::u32(2)), + &Arc::>::const_word(&ctx, Value::u16(16)), ) .unwrap(); assert_eq!( diff --git a/src/merkle/amr.rs b/src/merkle/amr.rs index 10bb3ab2..e6d349ae 100644 --- a/src/merkle/amr.rs +++ b/src/merkle/amr.rs @@ -291,11 +291,13 @@ mod tests { use crate::jet::Core; use crate::node::{ConstructNode, JetConstructible}; + use crate::types; use std::sync::Arc; #[test] fn fixed_amr() { - let node = Arc::>::jet(Core::Verify) + let ctx = types::Context::new(); + let node = Arc::>::jet(&ctx, Core::Verify) .finalize_types_non_program() .unwrap(); // Checked against C implementation diff --git a/src/merkle/cmr.rs b/src/merkle/cmr.rs index 00ef9ff7..9093ac18 100644 --- a/src/merkle/cmr.rs +++ b/src/merkle/cmr.rs @@ -7,7 +7,7 @@ use crate::jet::Jet; use crate::node::{ CoreConstructible, DisconnectConstructible, JetConstructible, WitnessConstructible, }; -use crate::types::Error; +use crate::types::{self, Error}; use crate::{FailEntropy, Tmr, Value}; use 
hashes::sha256::Midstate; @@ -258,103 +258,137 @@ impl Cmr { /// same generic construction code that nodes are. pub struct ConstructibleCmr { pub cmr: Cmr, + pub inference_context: types::Context, } impl CoreConstructible for ConstructibleCmr { - fn iden() -> Self { - ConstructibleCmr { cmr: Cmr::iden() } + fn iden(inference_context: &types::Context) -> Self { + ConstructibleCmr { + cmr: Cmr::iden(), + inference_context: inference_context.shallow_clone(), + } } - fn unit() -> Self { - ConstructibleCmr { cmr: Cmr::unit() } + fn unit(inference_context: &types::Context) -> Self { + ConstructibleCmr { + cmr: Cmr::unit(), + inference_context: inference_context.shallow_clone(), + } } fn injl(child: &Self) -> Self { ConstructibleCmr { cmr: Cmr::injl(child.cmr), + inference_context: child.inference_context.shallow_clone(), } } fn injr(child: &Self) -> Self { ConstructibleCmr { cmr: Cmr::injl(child.cmr), + inference_context: child.inference_context.shallow_clone(), } } fn take(child: &Self) -> Self { ConstructibleCmr { cmr: Cmr::take(child.cmr), + inference_context: child.inference_context.shallow_clone(), } } fn drop_(child: &Self) -> Self { ConstructibleCmr { cmr: Cmr::drop(child.cmr), + inference_context: child.inference_context.shallow_clone(), } } fn comp(left: &Self, right: &Self) -> Result { + left.inference_context.check_eq(&right.inference_context)?; Ok(ConstructibleCmr { cmr: Cmr::comp(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), }) } fn case(left: &Self, right: &Self) -> Result { + left.inference_context.check_eq(&right.inference_context)?; Ok(ConstructibleCmr { cmr: Cmr::case(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), }) } fn assertl(left: &Self, right: Cmr) -> Result { Ok(ConstructibleCmr { cmr: Cmr::case(left.cmr, right), + inference_context: left.inference_context.shallow_clone(), }) } fn assertr(left: Cmr, right: &Self) -> Result { Ok(ConstructibleCmr { cmr: Cmr::case(left, 
right.cmr), + inference_context: right.inference_context.shallow_clone(), }) } fn pair(left: &Self, right: &Self) -> Result { + left.inference_context.check_eq(&right.inference_context)?; Ok(ConstructibleCmr { cmr: Cmr::pair(left.cmr, right.cmr), + inference_context: left.inference_context.shallow_clone(), }) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { ConstructibleCmr { cmr: Cmr::fail(entropy), + inference_context: inference_context.shallow_clone(), } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &types::Context, word: Arc) -> Self { ConstructibleCmr { cmr: Cmr::const_word(&word), + inference_context: inference_context.shallow_clone(), } } + + fn inference_context(&self) -> &types::Context { + &self.inference_context + } } impl DisconnectConstructible for ConstructibleCmr { + // Specifically with disconnect we don't check for consistency between the + // type inference context of the disconnected node, if any, and that of + // the left node. The idea is, from the point of view of (Constructible)Cmr, + // the right child of disconnect doesn't even exist. 
fn disconnect(left: &Self, _right: &X) -> Result { Ok(ConstructibleCmr { cmr: Cmr::disconnect(left.cmr), + inference_context: left.inference_context.shallow_clone(), }) } } impl WitnessConstructible for ConstructibleCmr { - fn witness(_witness: W) -> Self { + fn witness(inference_context: &types::Context, _witness: W) -> Self { ConstructibleCmr { cmr: Cmr::witness(), + inference_context: inference_context.shallow_clone(), } } } impl JetConstructible for ConstructibleCmr { - fn jet(jet: J) -> Self { - ConstructibleCmr { cmr: jet.cmr() } + fn jet(inference_context: &types::Context, jet: J) -> Self { + ConstructibleCmr { + cmr: jet.cmr(), + inference_context: inference_context.shallow_clone(), + } } } @@ -370,7 +404,8 @@ mod tests { #[test] fn cmr_display_unit() { - let c = Arc::>::unit(); + let ctx = types::Context::new(); + let c = Arc::>::unit(&ctx); assert_eq!( c.cmr().to_string(), @@ -397,7 +432,8 @@ mod tests { #[test] fn bit_cmr() { - let unit = Arc::>::unit(); + let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let bit0 = Arc::>::injl(&unit); assert_eq!(bit0.cmr(), Cmr::BITS[0]); diff --git a/src/node/commit.rs b/src/node/commit.rs index a2dc81c2..ab7ee71a 100644 --- a/src/node/commit.rs +++ b/src/node/commit.rs @@ -202,7 +202,10 @@ impl CommitNode { /// Convert a [`CommitNode`] back to a [`ConstructNode`] by redoing type inference pub fn unfinalize_types(&self) -> Result>, types::Error> { - struct UnfinalizeTypes(PhantomData); + struct UnfinalizeTypes { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Construct> for UnfinalizeTypes { type Error = types::Error; @@ -232,11 +235,17 @@ impl CommitNode { .map(|node| node.arrow()) .map_disconnect(|maybe_node| maybe_node.as_ref().map(|node| node.arrow())); let inner = inner.disconnect_as_ref(); // lol sigh rust - Ok(ConstructData::new(Arrow::from_inner(inner)?)) + Ok(ConstructData::new(Arrow::from_inner( + &self.inference_context, + inner, + )?)) } } - self.convert::>, 
_, _>(&mut UnfinalizeTypes(PhantomData)) + self.convert::>, _, _>(&mut UnfinalizeTypes { + inference_context: types::Context::new(), + phantom: PhantomData, + }) } /// Decode a Simplicity program from bits, without witness data. diff --git a/src/node/construct.rs b/src/node/construct.rs index 29999c36..49c96a8c 100644 --- a/src/node/construct.rs +++ b/src/node/construct.rs @@ -165,16 +165,16 @@ impl ConstructData { } impl CoreConstructible for ConstructData { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { ConstructData { - arrow: Arrow::iden(), + arrow: Arrow::iden(inference_context), phantom: PhantomData, } } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { ConstructData { - arrow: Arrow::unit(), + arrow: Arrow::unit(inference_context), phantom: PhantomData, } } @@ -242,19 +242,23 @@ impl CoreConstructible for ConstructData { }) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { ConstructData { - arrow: Arrow::fail(entropy), + arrow: Arrow::fail(inference_context, entropy), phantom: PhantomData, } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &types::Context, word: Arc) -> Self { ConstructData { - arrow: Arrow::const_word(word), + arrow: Arrow::const_word(inference_context, word), phantom: PhantomData, } } + + fn inference_context(&self) -> &types::Context { + self.arrow.inference_context() + } } impl DisconnectConstructible>>> for ConstructData { @@ -271,18 +275,18 @@ impl DisconnectConstructible>>> for Construc } impl WitnessConstructible for ConstructData { - fn witness(witness: NoWitness) -> Self { + fn witness(inference_context: &types::Context, witness: NoWitness) -> Self { ConstructData { - arrow: Arrow::witness(witness), + arrow: Arrow::witness(inference_context, witness), phantom: PhantomData, } } } impl JetConstructible for ConstructData { - fn jet(jet: J) -> Self { + fn jet(inference_context: 
&types::Context, jet: J) -> Self { ConstructData { - arrow: Arrow::jet(jet), + arrow: Arrow::jet(inference_context, jet), phantom: PhantomData, } } @@ -295,7 +299,8 @@ mod tests { #[test] fn occurs_check_error() { - let iden = Arc::>::iden(); + let ctx = types::Context::new(); + let iden = Arc::>::iden(&ctx); let node = Arc::>::disconnect(&iden, &Some(Arc::clone(&iden))).unwrap(); assert!(matches!( @@ -306,8 +311,9 @@ mod tests { #[test] fn occurs_check_2() { + let ctx = types::Context::new(); // A more complicated occurs-check test that caused a deadlock in the past. - let iden = Arc::>::iden(); + let iden = Arc::>::iden(&ctx); let injr = Arc::>::injr(&iden); let pair = Arc::>::pair(&injr, &iden).unwrap(); let drop = Arc::>::drop_(&pair); @@ -326,8 +332,9 @@ mod tests { #[test] fn occurs_check_3() { + let ctx = types::Context::new(); // A similar example that caused a slightly different deadlock in the past. - let wit = Arc::>::witness(NoWitness); + let wit = Arc::>::witness(&ctx, NoWitness); let drop = Arc::>::drop_(&wit); let comp1 = Arc::>::comp(&drop, &drop).unwrap(); @@ -353,7 +360,8 @@ mod tests { #[test] fn type_check_error() { - let unit = Arc::>::unit(); + let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let case = Arc::>::case(&unit, &unit).unwrap(); assert!(matches!( @@ -364,26 +372,30 @@ mod tests { #[test] fn scribe() { - let unit = Arc::>::unit(); + // Ok to use same type inference context for all the below tests, + // since everything has concrete types and anyway we only care + // about CMRs, for which type inference is irrelevant. 
+ let ctx = types::Context::new(); + let unit = Arc::>::unit(&ctx); let bit0 = Arc::>::injl(&unit); let bit1 = Arc::>::injr(&unit); let bits01 = Arc::>::pair(&bit0, &bit1).unwrap(); assert_eq!( unit.cmr(), - Arc::>::scribe(&Value::Unit).cmr() + Arc::>::scribe(&ctx, &Value::Unit).cmr() ); assert_eq!( bit0.cmr(), - Arc::>::scribe(&Value::u1(0)).cmr() + Arc::>::scribe(&ctx, &Value::u1(0)).cmr() ); assert_eq!( bit1.cmr(), - Arc::>::scribe(&Value::u1(1)).cmr() + Arc::>::scribe(&ctx, &Value::u1(1)).cmr() ); assert_eq!( bits01.cmr(), - Arc::>::scribe(&Value::u2(1)).cmr() + Arc::>::scribe(&ctx, &Value::u2(1)).cmr() ); } } diff --git a/src/node/mod.rs b/src/node/mod.rs index c0b7b943..f057c1a6 100644 --- a/src/node/mod.rs +++ b/src/node/mod.rs @@ -130,10 +130,13 @@ pub trait Constructible: + CoreConstructible + Sized { - fn from_inner(inner: Inner<&Self, J, &X, W>) -> Result { + fn from_inner( + inference_context: &types::Context, + inner: Inner<&Self, J, &X, W>, + ) -> Result { match inner { - Inner::Iden => Ok(Self::iden()), - Inner::Unit => Ok(Self::unit()), + Inner::Iden => Ok(Self::iden(inference_context)), + Inner::Unit => Ok(Self::unit(inference_context)), Inner::InjL(child) => Ok(Self::injl(child)), Inner::InjR(child) => Ok(Self::injr(child)), Inner::Take(child) => Ok(Self::take(child)), @@ -144,10 +147,10 @@ pub trait Constructible: Inner::AssertR(l_cmr, right) => Self::assertr(l_cmr, right), Inner::Pair(left, right) => Self::pair(left, right), Inner::Disconnect(left, right) => Self::disconnect(left, right), - Inner::Fail(entropy) => Ok(Self::fail(entropy)), - Inner::Word(ref w) => Ok(Self::const_word(Arc::clone(w))), - Inner::Jet(j) => Ok(Self::jet(j)), - Inner::Witness(w) => Ok(Self::witness(w)), + Inner::Fail(entropy) => Ok(Self::fail(inference_context, entropy)), + Inner::Word(ref w) => Ok(Self::const_word(inference_context, Arc::clone(w))), + Inner::Jet(j) => Ok(Self::jet(inference_context, j)), + Inner::Witness(w) => Ok(Self::witness(inference_context, w)), } 
} } @@ -162,8 +165,8 @@ impl Constructible for T where } pub trait CoreConstructible: Sized { - fn iden() -> Self; - fn unit() -> Self; + fn iden(inference_context: &types::Context) -> Self; + fn unit(inference_context: &types::Context) -> Self; fn injl(child: &Self) -> Self; fn injr(child: &Self) -> Self; fn take(child: &Self) -> Self; @@ -173,17 +176,20 @@ pub trait CoreConstructible: Sized { fn assertl(left: &Self, right: Cmr) -> Result; fn assertr(left: Cmr, right: &Self) -> Result; fn pair(left: &Self, right: &Self) -> Result; - fn fail(entropy: FailEntropy) -> Self; - fn const_word(word: Arc) -> Self; + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self; + fn const_word(inference_context: &types::Context, word: Arc) -> Self; + + /// Accessor for the type inference context used to create the object. + fn inference_context(&self) -> &types::Context; /// Create a DAG that takes any input and returns `value` as constant output. /// /// _Overall type: A → B where value: B_ - fn scribe(value: &Value) -> Self { + fn scribe(inference_context: &types::Context, value: &Value) -> Self { let mut stack = vec![]; for data in value.post_order_iter::() { match data.node { - Value::Unit => stack.push(Self::unit()), + Value::Unit => stack.push(Self::unit(inference_context)), Value::SumL(..) => { let child = stack.pop().unwrap(); stack.push(Self::injl(&child)); @@ -208,16 +214,16 @@ pub trait CoreConstructible: Sized { /// Create a DAG that takes any input and returns bit `0` as constant output. /// /// _Overall type: A → 2_ - fn bit_false() -> Self { - let unit = Self::unit(); + fn bit_false(inference_context: &types::Context) -> Self { + let unit = Self::unit(inference_context); Self::injl(&unit) } /// Create a DAG that takes any input and returns bit `1` as constant output. 
/// /// _Overall type: A → 2_ - fn bit_true() -> Self { - let unit = Self::unit(); + fn bit_true(inference_context: &types::Context) -> Self { + let unit = Self::unit(inference_context); Self::injr(&unit) } @@ -241,7 +247,7 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn assert(child: &Self, hash: Cmr) -> Result { - let unit = Self::unit(); + let unit = Self::unit(child.inference_context()); let pair_child_unit = Self::pair(child, &unit)?; let assertr_hidden_unit = Self::assertr(hash, &unit)?; @@ -255,10 +261,10 @@ pub trait CoreConstructible: Sized { /// _Type inference will fail if children are not of the correct type._ #[allow(clippy::should_implement_trait)] fn not(child: &Self) -> Result { - let unit = Self::unit(); + let unit = Self::unit(child.inference_context()); let pair_child_unit = Self::pair(child, &unit)?; - let bit_true = Self::bit_true(); - let bit_false = Self::bit_false(); + let bit_true = Self::bit_true(child.inference_context()); + let bit_false = Self::bit_false(child.inference_context()); let case_true_false = Self::case(&bit_true, &bit_false)?; Self::comp(&pair_child_unit, &case_true_false) @@ -270,9 +276,11 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn and(left: &Self, right: &Self) -> Result { - let iden = Self::iden(); + left.inference_context() + .check_eq(right.inference_context())?; + let iden = Self::iden(left.inference_context()); let pair_left_iden = Self::pair(left, &iden)?; - let bit_false = Self::bit_false(); + let bit_false = Self::bit_false(left.inference_context()); let drop_right = Self::drop_(right); let case_false_right = Self::case(&bit_false, &drop_right)?; @@ -285,10 +293,12 @@ pub trait CoreConstructible: Sized { /// /// _Type inference will fail if children are not of the correct type._ fn or(left: &Self, right: &Self) -> Result { - let iden = Self::iden(); + 
left.inference_context() + .check_eq(right.inference_context())?; + let iden = Self::iden(left.inference_context()); let pair_left_iden = Self::pair(left, &iden)?; let drop_right = Self::drop_(right); - let bit_true = Self::bit_true(); + let bit_true = Self::bit_true(left.inference_context()); let case_right_true = Self::case(&drop_right, &bit_true)?; Self::comp(&pair_left_iden, &case_right_true) @@ -300,11 +310,11 @@ pub trait DisconnectConstructible: Sized { } pub trait JetConstructible: Sized { - fn jet(jet: J) -> Self; + fn jet(inference_context: &types::Context, jet: J) -> Self; } pub trait WitnessConstructible: Sized { - fn witness(witness: W) -> Self; + fn witness(inference_context: &types::Context, witness: W) -> Self; } /// A node in a Simplicity expression. @@ -373,18 +383,18 @@ where N: Marker, N::CachedData: CoreConstructible, { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { Arc::new(Node { cmr: Cmr::iden(), - data: N::CachedData::iden(), + data: N::CachedData::iden(inference_context), inner: Inner::Iden, }) } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { Arc::new(Node { cmr: Cmr::unit(), - data: N::CachedData::unit(), + data: N::CachedData::unit(inference_context), inner: Inner::Unit, }) } @@ -461,21 +471,25 @@ where })) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { Arc::new(Node { cmr: Cmr::fail(entropy), - data: N::CachedData::fail(entropy), + data: N::CachedData::fail(inference_context, entropy), inner: Inner::Fail(entropy), }) } - fn const_word(value: Arc) -> Self { + fn const_word(inference_context: &types::Context, value: Arc) -> Self { Arc::new(Node { cmr: Cmr::const_word(&value), - data: N::CachedData::const_word(Arc::clone(&value)), + data: N::CachedData::const_word(inference_context, Arc::clone(&value)), inner: Inner::Word(value), }) } + + fn inference_context(&self) -> &types::Context { + 
self.data.inference_context() + } } impl DisconnectConstructible for Arc> @@ -497,10 +511,10 @@ where N: Marker, N::CachedData: WitnessConstructible, { - fn witness(value: N::Witness) -> Self { + fn witness(inference_context: &types::Context, value: N::Witness) -> Self { Arc::new(Node { cmr: Cmr::witness(), - data: N::CachedData::witness(value.clone()), + data: N::CachedData::witness(inference_context, value.clone()), inner: Inner::Witness(value), }) } @@ -511,10 +525,10 @@ where N: Marker, N::CachedData: JetConstructible, { - fn jet(jet: N::Jet) -> Self { + fn jet(inference_context: &types::Context, jet: N::Jet) -> Self { Arc::new(Node { cmr: Cmr::jet(jet), - data: N::CachedData::jet(jet), + data: N::CachedData::jet(inference_context, jet), inner: Inner::Jet(jet), }) } diff --git a/src/node/redeem.rs b/src/node/redeem.rs index 0679ec22..f621de29 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -223,7 +223,10 @@ impl RedeemNode { /// Convert a [`RedeemNode`] back into a [`WitnessNode`] /// by loosening the finalized types, witness data and disconnected branches. pub fn to_witness_node(&self) -> Arc> { - struct ToWitness(PhantomData); + struct ToWitness { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Witness> for ToWitness { type Error = (); @@ -258,12 +261,16 @@ impl RedeemNode { let inner = inner .map(|node| node.cached_data()) .map_witness(|maybe_value| maybe_value.clone()); - Ok(WitnessData::from_inner(inner).expect("types are already finalized")) + Ok(WitnessData::from_inner(&self.inference_context, inner) + .expect("types were already finalized")) } } - self.convert::(&mut ToWitness(PhantomData)) - .unwrap() + self.convert::(&mut ToWitness { + inference_context: types::Context::new(), + phantom: PhantomData, + }) + .unwrap() } /// Decode a Simplicity program from bits, including the witness data. 
diff --git a/src/node/witness.rs b/src/node/witness.rs index 41cd9086..b1aea3f3 100644 --- a/src/node/witness.rs +++ b/src/node/witness.rs @@ -74,7 +74,10 @@ impl WitnessNode { } pub fn prune_and_retype(&self) -> Arc { - struct Retyper(PhantomData); + struct Retyper { + inference_context: types::Context, + phantom: PhantomData, + } impl Converter, Witness> for Retyper { type Error = types::Error; @@ -131,7 +134,8 @@ impl WitnessNode { .map(|node| node.cached_data()) .map_witness(Option::>::clone); // This next line does the actual retyping. - let mut retyped = WitnessData::from_inner(converted_inner)?; + let mut retyped = + WitnessData::from_inner(&self.inference_context, converted_inner)?; // Sometimes we set the prune bit on nodes without setting that // of their children; in this case the prune bit inferred from // `converted_inner` will be incorrect. @@ -144,8 +148,11 @@ impl WitnessNode { // FIXME after running the `ReTyper` we should run a `WitnessShrinker` which // shrinks the witness data in case the ReTyper shrank its types. 
- self.convert::(&mut Retyper(PhantomData)) - .expect("type inference won't fail if it succeeded before") + self.convert::(&mut Retyper { + inference_context: types::Context::new(), + phantom: PhantomData, + }) + .expect("type inference won't fail if it succeeded before") } pub fn finalize(&self) -> Result>, Error> { @@ -228,17 +235,17 @@ pub struct WitnessData { } impl CoreConstructible for WitnessData { - fn iden() -> Self { + fn iden(inference_context: &types::Context) -> Self { WitnessData { - arrow: Arrow::iden(), + arrow: Arrow::iden(inference_context), must_prune: false, phantom: PhantomData, } } - fn unit() -> Self { + fn unit(inference_context: &types::Context) -> Self { WitnessData { - arrow: Arrow::unit(), + arrow: Arrow::unit(inference_context), must_prune: false, phantom: PhantomData, } @@ -319,22 +326,26 @@ impl CoreConstructible for WitnessData { }) } - fn fail(entropy: FailEntropy) -> Self { + fn fail(inference_context: &types::Context, entropy: FailEntropy) -> Self { // Fail nodes always get pruned. 
WitnessData { - arrow: Arrow::fail(entropy), + arrow: Arrow::fail(inference_context, entropy), must_prune: true, phantom: PhantomData, } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &types::Context, word: Arc) -> Self { WitnessData { - arrow: Arrow::const_word(word), + arrow: Arrow::const_word(inference_context, word), must_prune: false, phantom: PhantomData, } } + + fn inference_context(&self) -> &types::Context { + self.arrow.inference_context() + } } impl DisconnectConstructible>>> for WitnessData { @@ -349,9 +360,9 @@ impl DisconnectConstructible>>> for WitnessDat } impl WitnessConstructible>> for WitnessData { - fn witness(witness: Option>) -> Self { + fn witness(inference_context: &types::Context, witness: Option>) -> Self { WitnessData { - arrow: Arrow::witness(NoWitness), + arrow: Arrow::witness(inference_context, NoWitness), must_prune: witness.is_none(), phantom: PhantomData, } @@ -359,9 +370,9 @@ impl WitnessConstructible>> for WitnessData { } impl JetConstructible for WitnessData { - fn jet(jet: J) -> Self { + fn jet(inference_context: &types::Context, jet: J) -> Self { WitnessData { - arrow: Arrow::jet(jet), + arrow: Arrow::jet(inference_context, jet), must_prune: false, phantom: PhantomData, } diff --git a/src/policy/ast.rs b/src/policy/ast.rs index 8c87c476..94360646 100644 --- a/src/policy/ast.rs +++ b/src/policy/ast.rs @@ -17,6 +17,7 @@ use crate::node::{ ConstructNode, CoreConstructible, JetConstructible, NoWitness, WitnessConstructible, }; use crate::policy::serialize::{self, AssemblyConstructible}; +use crate::types; use crate::{Cmr, CommitNode, FailEntropy}; use crate::{SimplicityKey, ToXOnlyPubkey, Translator}; @@ -58,7 +59,7 @@ pub enum Policy { impl Policy { /// Serializes the policy as a Simplicity fragment, with all witness nodes unpopulated. 
- fn serialize_no_witness(&self) -> Option + fn serialize_no_witness(&self, inference_context: &types::Context) -> Option where N: CoreConstructible + JetConstructible @@ -66,53 +67,60 @@ impl Policy { + AssemblyConstructible, { match *self { - Policy::Unsatisfiable(entropy) => Some(serialize::unsatisfiable(entropy)), - Policy::Trivial => Some(serialize::trivial()), - Policy::After(n) => Some(serialize::after(n)), - Policy::Older(n) => Some(serialize::older(n)), - Policy::Key(ref key) => Some(serialize::key(key, NoWitness)), - Policy::Sha256(ref hash) => Some(serialize::sha256::(hash, NoWitness)), + Policy::Unsatisfiable(entropy) => { + Some(serialize::unsatisfiable(inference_context, entropy)) + } + Policy::Trivial => Some(serialize::trivial(inference_context)), + Policy::After(n) => Some(serialize::after(inference_context, n)), + Policy::Older(n) => Some(serialize::older(inference_context, n)), + Policy::Key(ref key) => Some(serialize::key(inference_context, key, NoWitness)), + Policy::Sha256(ref hash) => Some(serialize::sha256::( + inference_context, + hash, + NoWitness, + )), Policy::And { ref left, ref right, } => { - let left = left.serialize_no_witness()?; - let right = right.serialize_no_witness()?; + let left = left.serialize_no_witness(inference_context)?; + let right = right.serialize_no_witness(inference_context)?; Some(serialize::and(&left, &right)) } Policy::Or { ref left, ref right, } => { - let left = left.serialize_no_witness()?; - let right = right.serialize_no_witness()?; + let left = left.serialize_no_witness(inference_context)?; + let right = right.serialize_no_witness(inference_context)?; Some(serialize::or(&left, &right, NoWitness)) } Policy::Threshold(k, ref subs) => { let k = u32::try_from(k).expect("can have k at most 2^32 in a threshold"); let subs = subs .iter() - .map(Self::serialize_no_witness) + .map(|sub| sub.serialize_no_witness(inference_context)) .collect::>>()?; let wits = iter::repeat(NoWitness) .take(subs.len()) .collect::>(); 
Some(serialize::threshold(k, &subs, &wits)) } - Policy::Assembly(cmr) => N::assembly(cmr), + Policy::Assembly(cmr) => N::assembly(inference_context, cmr), } } /// Return the program commitment of the policy. pub fn commit(&self) -> Option>> { - let construct: Arc> = self.serialize_no_witness()?; + let construct: Arc> = + self.serialize_no_witness(&types::Context::new())?; let commit = construct.finalize_types().expect("policy has sound types"); Some(commit) } /// Return the CMR of the policy. pub fn cmr(&self) -> Cmr { - self.serialize_no_witness::() + self.serialize_no_witness::(&types::Context::new()) .expect("CMR is defined for asm fragment") .cmr } diff --git a/src/policy/satisfy.rs b/src/policy/satisfy.rs index f8888a29..94d43be2 100644 --- a/src/policy/satisfy.rs +++ b/src/policy/satisfy.rs @@ -4,6 +4,7 @@ use crate::analysis::Cost; use crate::jet::Elements; use crate::node::{RedeemNode, WitnessNode}; use crate::policy::ToXOnlyPubkey; +use crate::types; use crate::{Cmr, Error, Policy, Value}; use elements::bitcoin; @@ -93,19 +94,22 @@ impl Satisfier for elements::LockTime { impl Policy { fn satisfy_internal>( &self, + inference_context: &types::Context, satisfier: &S, ) -> Result>, Error> { let node = match *self { - Policy::Unsatisfiable(entropy) => super::serialize::unsatisfiable(entropy), - Policy::Trivial => super::serialize::trivial(), + Policy::Unsatisfiable(entropy) => { + super::serialize::unsatisfiable(inference_context, entropy) + } + Policy::Trivial => super::serialize::trivial(inference_context), Policy::Key(ref key) => { let sig_wit = satisfier .lookup_tap_leaf_script_sig(key, &TapLeafHash::all_zeros()) .map(|sig| Value::u512_from_slice(sig.sig.as_ref())); - super::serialize::key(key, sig_wit) + super::serialize::key(inference_context, key, sig_wit) } Policy::After(n) => { - let node = super::serialize::after::>(n); + let node = super::serialize::after::>(inference_context, n); let height = Height::from_consensus(n).expect("timelock is valid"); 
if satisfier.check_after(elements::LockTime::Blocks(height)) { node @@ -114,7 +118,7 @@ impl Policy { } } Policy::Older(n) => { - let node = super::serialize::older::>(n); + let node = super::serialize::older::>(inference_context, n); if satisfier.check_older(elements::Sequence((n).into())) { node } else { @@ -125,22 +129,22 @@ impl Policy { let preimage_wit = satisfier .lookup_sha256(hash) .map(|preimage| Value::u256_from_slice(preimage.as_ref())); - super::serialize::sha256::(hash, preimage_wit) + super::serialize::sha256::(inference_context, hash, preimage_wit) } Policy::And { ref left, ref right, } => { - let left = left.satisfy_internal(satisfier)?; - let right = right.satisfy_internal(satisfier)?; + let left = left.satisfy_internal(inference_context, satisfier)?; + let right = right.satisfy_internal(inference_context, satisfier)?; super::serialize::and(&left, &right) } Policy::Or { ref left, ref right, } => { - let left = left.satisfy_internal(satisfier)?; - let right = right.satisfy_internal(satisfier)?; + let left = left.satisfy_internal(inference_context, satisfier)?; + let right = right.satisfy_internal(inference_context, satisfier)?; let take_right = match (left.must_prune(), right.must_prune()) { (false, false) => { @@ -165,7 +169,7 @@ impl Policy { Policy::Threshold(k, ref subs) => { let nodes: Result>>, Error> = subs .iter() - .map(|sub| sub.satisfy_internal(satisfier)) + .map(|sub| sub.satisfy_internal(inference_context, satisfier)) .collect(); let mut nodes = nodes?; let mut costs = vec![Cost::CONSENSUS_MAX; subs.len()]; @@ -215,7 +219,7 @@ impl Policy { &self, satisfier: &S, ) -> Result>, Error> { - let witnode = self.satisfy_internal(satisfier)?; + let witnode = self.satisfy_internal(&types::Context::new(), satisfier)?; if witnode.must_prune() { Err(Error::IncompleteFinalization) } else { @@ -626,17 +630,18 @@ mod tests { #[test] fn satisfy_asm() { + let ctx = types::Context::new(); let env = ElementsEnv::dummy(); let mut satisfier = 
get_satisfier(&env); let mut assert_branch = |witness0: Arc, witness1: Arc| { let asm_program = serialize::verify_bexp( &Arc::>::pair( - &Arc::>::witness(Some(witness0.clone())), - &Arc::>::witness(Some(witness1.clone())), + &Arc::>::witness(&ctx, Some(witness0.clone())), + &Arc::>::witness(&ctx, Some(witness1.clone())), ) .expect("sound types"), - &Arc::>::jet(Elements::Eq8), + &Arc::>::jet(&ctx, Elements::Eq8), ); let cmr = asm_program.cmr(); satisfier.assembly.insert(cmr, asm_program); diff --git a/src/policy/serialize.rs b/src/policy/serialize.rs index ad6724fc..2c586657 100644 --- a/src/policy/serialize.rs +++ b/src/policy/serialize.rs @@ -5,6 +5,7 @@ use crate::jet::{Elements, Jet}; use crate::merkle::cmr::ConstructibleCmr; use crate::node::{CoreConstructible, JetConstructible, WitnessConstructible}; +use crate::types; use crate::{Cmr, ConstructNode, ToXOnlyPubkey}; use crate::{FailEntropy, Value}; @@ -16,69 +17,72 @@ pub trait AssemblyConstructible: Sized { /// Construct the assembly fragment with the given CMR. /// /// The construction fails if the CMR alone is not enough information to construct the object. 
- fn assembly(cmr: Cmr) -> Option; + fn assembly(inference_context: &types::Context, cmr: Cmr) -> Option; } impl AssemblyConstructible for ConstructibleCmr { - fn assembly(cmr: Cmr) -> Option { - Some(ConstructibleCmr { cmr }) + fn assembly(inference_context: &types::Context, cmr: Cmr) -> Option { + Some(ConstructibleCmr { + cmr, + inference_context: inference_context.shallow_clone(), + }) } } impl AssemblyConstructible for Arc> { - fn assembly(_cmr: Cmr) -> Option { + fn assembly(_: &types::Context, _cmr: Cmr) -> Option { None } } -pub fn unsatisfiable(entropy: FailEntropy) -> N +pub fn unsatisfiable(inference_context: &types::Context, entropy: FailEntropy) -> N where N: CoreConstructible, { - N::fail(entropy) + N::fail(inference_context, entropy) } -pub fn trivial() -> N +pub fn trivial(inference_context: &types::Context) -> N where N: CoreConstructible, { - N::unit() + N::unit(inference_context) } -pub fn key(key: &Pk, witness: W) -> N +pub fn key(inference_context: &types::Context, key: &Pk, witness: W) -> N where Pk: ToXOnlyPubkey, N: CoreConstructible + JetConstructible + WitnessConstructible, { let key_value = Value::u256_from_slice(&key.to_x_only_pubkey().serialize()); - let const_key = N::const_word(key_value); - let sighash_all = N::jet(Elements::SigAllHash); + let const_key = N::const_word(inference_context, key_value); + let sighash_all = N::jet(inference_context, Elements::SigAllHash); let pair_key_msg = N::pair(&const_key, &sighash_all).expect("consistent types"); - let witness = N::witness(witness); + let witness = N::witness(inference_context, witness); let pair_key_msg_sig = N::pair(&pair_key_msg, &witness).expect("consistent types"); - let bip_0340_verify = N::jet(Elements::Bip0340Verify); + let bip_0340_verify = N::jet(inference_context, Elements::Bip0340Verify); N::comp(&pair_key_msg_sig, &bip_0340_verify).expect("consistent types") } -pub fn after(n: u32) -> N +pub fn after(inference_context: &types::Context, n: u32) -> N where N: 
CoreConstructible + JetConstructible, { let n_value = Value::u32(n); - let const_n = N::const_word(n_value); - let check_lock_height = N::jet(Elements::CheckLockHeight); + let const_n = N::const_word(inference_context, n_value); + let check_lock_height = N::jet(inference_context, Elements::CheckLockHeight); N::comp(&const_n, &check_lock_height).expect("consistent types") } -pub fn older(n: u16) -> N +pub fn older(inference_context: &types::Context, n: u16) -> N where N: CoreConstructible + JetConstructible, { let n_value = Value::u16(n); - let const_n = N::const_word(n_value); - let check_lock_distance = N::jet(Elements::CheckLockDistance); + let const_n = N::const_word(inference_context, n_value); + let check_lock_distance = N::jet(inference_context, Elements::CheckLockDistance); N::comp(&const_n, &check_lock_distance).expect("consistent types") } @@ -87,11 +91,11 @@ pub fn compute_sha256(witness256: &N) -> N where N: CoreConstructible + JetConstructible, { - let ctx = N::jet(Elements::Sha256Ctx8Init); + let ctx = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Init); let pair_ctx_witness = N::pair(&ctx, witness256).expect("consistent types"); - let add256 = N::jet(Elements::Sha256Ctx8Add32); + let add256 = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Add32); let digest_ctx = N::comp(&pair_ctx_witness, &add256).expect("consistent types"); - let finalize = N::jet(Elements::Sha256Ctx8Finalize); + let finalize = N::jet(witness256.inference_context(), Elements::Sha256Ctx8Finalize); N::comp(&digest_ctx, &finalize).expect("consistent types") } @@ -99,22 +103,27 @@ pub fn verify_bexp(input: &N, bexp: &N) -> N where N: CoreConstructible + JetConstructible, { + assert_eq!( + input.inference_context(), + bexp.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); let computed_bexp = N::comp(input, bexp).expect("consistent types"); - let verify = N::jet(Elements::Verify); + let verify = 
N::jet(input.inference_context(), Elements::Verify); N::comp(&computed_bexp, &verify).expect("consistent types") } -pub fn sha256(hash: &Pk::Sha256, witness: W) -> N +pub fn sha256(inference_context: &types::Context, hash: &Pk::Sha256, witness: W) -> N where Pk: ToXOnlyPubkey, N: CoreConstructible + JetConstructible + WitnessConstructible, { let hash_value = Value::u256_from_slice(Pk::to_sha256(hash).as_ref()); - let const_hash = N::const_word(hash_value); - let witness256 = N::witness(witness); + let const_hash = N::const_word(inference_context, hash_value); + let witness256 = N::witness(inference_context, witness); let computed_hash = compute_sha256(&witness256); let pair_hash_computed_hash = N::pair(&const_hash, &computed_hash).expect("consistent types"); - let eq256 = N::jet(Elements::Eq256); + let eq256 = N::jet(inference_context, Elements::Eq256); verify_bexp(&pair_hash_computed_hash, &eq256) } @@ -126,12 +135,12 @@ where N::comp(left, right).expect("consistent types") } -pub fn selector(witness_bit: W) -> N +pub fn selector(inference_context: &types::Context, witness_bit: W) -> N where N: CoreConstructible + WitnessConstructible, { - let witness = N::witness(witness_bit); - let unit = N::unit(); + let witness = N::witness(inference_context, witness_bit); + let unit = N::unit(inference_context); N::pair(&witness, &unit).expect("consistent types") } @@ -139,10 +148,15 @@ pub fn or(left: &N, right: &N, witness_bit: W) -> N where N: CoreConstructible + WitnessConstructible, { + assert_eq!( + left.inference_context(), + right.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); let drop_left = N::drop_(left); let drop_right = N::drop_(right); let case_left_right = N::case(&drop_left, &drop_right).expect("consistent types"); - let selector = selector(witness_bit); + let selector = selector(left.inference_context(), witness_bit); N::comp(&selector, &case_left_right).expect("consistent types") } @@ -155,14 +169,14 @@ 
where N: CoreConstructible + WitnessConstructible, { // 1 → 2 x 1 - let selector = selector(witness_bit); + let selector = selector(child.inference_context(), witness_bit); // 1 → 2^32 - let const_one = N::const_word(Value::u32(1)); + let const_one = N::const_word(child.inference_context(), Value::u32(1)); // 1 → 2^32 let child_one = N::comp(child, &const_one).expect("consistent types"); // 1 → 2^32 - let const_zero = N::const_word(Value::u32(0)); + let const_zero = N::const_word(child.inference_context(), Value::u32(0)); // 1 × 1 → 2^32 let drop_left = N::drop_(&const_zero); @@ -182,14 +196,19 @@ pub fn thresh_add(sum: &N, summand: &N) -> N where N: CoreConstructible + JetConstructible, { + assert_eq!( + sum.inference_context(), + summand.inference_context(), + "cannot compose policy fragments with different type inference contexts", + ); // 1 → 2^32 × 2^32 let pair_sum_summand = N::pair(sum, summand).expect("consistent types"); // 2^32 × 2^32 → 2 × 2^32 - let add32 = N::jet(Elements::Add32); + let add32 = N::jet(sum.inference_context(), Elements::Add32); // 1 → 2 x 2^32 let full_sum = N::comp(&pair_sum_summand, &add32).expect("consistent types"); // 2^32 → 2^32 - let iden = N::iden(); + let iden = N::iden(sum.inference_context()); // 2 × 2^32 → 2^32 let drop_iden = N::drop_(&iden); @@ -205,11 +224,11 @@ where N: CoreConstructible + JetConstructible, { // 1 → 2^32 - let const_k = N::const_word(Value::u32(k)); + let const_k = N::const_word(sum.inference_context(), Value::u32(k)); // 1 → 2^32 × 2^32 let pair_k_sum = N::pair(&const_k, sum).expect("consistent types"); // 2^32 × 2^32 → 2 - let eq32 = N::jet(Elements::Eq32); + let eq32 = N::jet(sum.inference_context(), Elements::Eq32); // 1 → 1 verify_bexp(&pair_k_sum, &eq32) diff --git a/src/types/arrow.rs b/src/types/arrow.rs index 08ba86f7..69faf17c 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -18,19 +18,30 @@ use crate::node::{ CoreConstructible, DisconnectConstructible, JetConstructible, 
NoDisconnect, WitnessConstructible, }; -use crate::types::{Bound, Error, Final, Type}; +use crate::types::{Bound, Context, Error, Final, Type}; use crate::{jet::Jet, Value}; use super::variable::new_name; /// A container for an expression's source and target types, whether or not /// these types are complete. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Arrow { /// The source type pub source: Type, /// The target type pub target: Type, + /// Type inference context for both types. + pub inference_context: Context, +} + +// Having `Clone` makes it easier to derive Clone on structures +// that contain Arrow, even though it is potentially confusing +// to use `.clone` to mean a shallow clone. +impl Clone for Arrow { + fn clone(&self) -> Self { + self.shallow_clone() + } } impl fmt::Display for Arrow { @@ -79,6 +90,7 @@ impl Arrow { Arrow { source: self.source.shallow_clone(), target: self.target.shallow_clone(), + inference_context: self.inference_context.shallow_clone(), } } @@ -88,9 +100,20 @@ impl Arrow { /// an assertion, which for type-inference purposes means there are no bounds /// on the missing child. /// - /// If neither child is provided, this function will not raise an error; it - /// is the responsibility of the caller to detect this case and error elsewhere. + /// # Panics + /// + /// If neither child is provided, this function will panic. 
fn for_case(lchild_arrow: Option<&Arrow>, rchild_arrow: Option<&Arrow>) -> Result { + if let (Some(left), Some(right)) = (lchild_arrow, rchild_arrow) { + left.inference_context.check_eq(&right.inference_context)?; + } + + let inference_context = match (lchild_arrow, rchild_arrow) { + (Some(left), _) => left.inference_context.shallow_clone(), + (_, Some(right)) => right.inference_context.shallow_clone(), + (None, None) => panic!("called `for_case` with no children"), + }; + let a = Type::free(new_name("case_a_")); let b = Type::free(new_name("case_b_")); let c = Type::free(new_name("case_c_")); @@ -120,11 +143,16 @@ impl Arrow { Ok(Arrow { source: prod_sum_a_b_c, target, + inference_context, }) } /// Helper function to combine code for the two `DisconnectConstructible` impls for [`Arrow`]. fn for_disconnect(lchild_arrow: &Arrow, rchild_arrow: &Arrow) -> Result { + lchild_arrow + .inference_context + .check_eq(&rchild_arrow.inference_context)?; + let a = Type::free(new_name("disconnect_a_")); let b = Type::free(new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); @@ -146,12 +174,13 @@ impl Arrow { Ok(Arrow { source: a, target: prod_b_d, + inference_context: lchild_arrow.inference_context.shallow_clone(), }) } } impl CoreConstructible for Arrow { - fn iden() -> Self { + fn iden(inference_context: &Context) -> Self { // Throughout this module, when two types are the same, we reuse a // pointer to them rather than creating distinct types and unifying // them. 
This theoretically could lead to more confusing errors for @@ -161,13 +190,15 @@ impl CoreConstructible for Arrow { Arrow { source: new.shallow_clone(), target: new, + inference_context: inference_context.shallow_clone(), } } - fn unit() -> Self { + fn unit(inference_context: &Context) -> Self { Arrow { source: Type::free(new_name("unit_src_")), target: Type::unit(), + inference_context: inference_context.shallow_clone(), } } @@ -178,6 +209,7 @@ impl CoreConstructible for Arrow { child.target.shallow_clone(), Type::free(new_name("injl_tgt_")), ), + inference_context: child.inference_context.shallow_clone(), } } @@ -188,6 +220,7 @@ impl CoreConstructible for Arrow { Type::free(new_name("injr_tgt_")), child.target.shallow_clone(), ), + inference_context: child.inference_context.shallow_clone(), } } @@ -198,6 +231,7 @@ impl CoreConstructible for Arrow { Type::free(new_name("take_src_")), ), target: child.target.shallow_clone(), + inference_context: child.inference_context.shallow_clone(), } } @@ -208,15 +242,18 @@ impl CoreConstructible for Arrow { child.source.shallow_clone(), ), target: child.target.shallow_clone(), + inference_context: child.inference_context.shallow_clone(), } } fn comp(left: &Self, right: &Self) -> Result { + left.inference_context.check_eq(&right.inference_context)?; left.target .unify(&right.source, "comp combinator: left target = right source")?; Ok(Arrow { source: left.source.shallow_clone(), target: right.target.shallow_clone(), + inference_context: left.inference_context.shallow_clone(), }) } @@ -233,22 +270,25 @@ impl CoreConstructible for Arrow { } fn pair(left: &Self, right: &Self) -> Result { + left.inference_context.check_eq(&right.inference_context)?; left.source .unify(&right.source, "pair combinator: left source = right source")?; Ok(Arrow { source: left.source.shallow_clone(), target: Type::product(left.target.shallow_clone(), right.target.shallow_clone()), + inference_context: left.inference_context.shallow_clone(), }) } - fn 
fail(_: crate::FailEntropy) -> Self { + fn fail(inference_context: &Context, _: crate::FailEntropy) -> Self { Arrow { source: Type::free(new_name("fail_src_")), target: Type::free(new_name("fail_tgt_")), + inference_context: inference_context.shallow_clone(), } } - fn const_word(word: Arc) -> Self { + fn const_word(inference_context: &Context, word: Arc) -> Self { let len = word.len(); assert!(len > 0, "Words must not be the empty bitstring"); assert!(len.is_power_of_two()); @@ -256,8 +296,13 @@ impl CoreConstructible for Arrow { Arrow { source: Type::unit(), target: Type::two_two_n(depth as usize), + inference_context: inference_context.shallow_clone(), } } + + fn inference_context(&self) -> &Context { + &self.inference_context + } } impl DisconnectConstructible for Arrow { @@ -273,6 +318,7 @@ impl DisconnectConstructible for Arrow { &Arrow { source: Type::free("disc_src".into()), target: Type::free("disc_tgt".into()), + inference_context: left.inference_context.shallow_clone(), }, ) } @@ -288,19 +334,21 @@ impl DisconnectConstructible> for Arrow { } impl JetConstructible for Arrow { - fn jet(jet: J) -> Self { + fn jet(inference_context: &Context, jet: J) -> Self { Arrow { source: jet.source_ty().to_type(), target: jet.target_ty().to_type(), + inference_context: inference_context.shallow_clone(), } } } impl WitnessConstructible for Arrow { - fn witness(_: W) -> Self { + fn witness(inference_context: &Context, _: W) -> Self { Arrow { source: Type::free(new_name("witness_src_")), target: Type::free(new_name("witness_tgt_")), + inference_context: inference_context.shallow_clone(), } } } diff --git a/src/types/context.rs b/src/types/context.rs new file mode 100644 index 00000000..c8585f93 --- /dev/null +++ b/src/types/context.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Type Inference Context +//! +//! When constructing a Simplicity program, you must first create a type inference +//! 
context, in which type inference occurs incrementally during construction. Each +//! leaf node (e.g. `unit` and `iden`) must explicitly refer to the type inference +//! context, while combinator nodes (e.g. `comp`) infer the context from their +//! children, raising an error if there are multiple children whose contexts don't +//! match. +//! +//! This helps to prevent situations in which users attempt to construct multiple +//! independent programs, but types in one program accidentally refer to types in +//! the other. +//! + +use std::fmt; +use std::sync::{Arc, Mutex}; + +use super::Bound; + +/// Type inference context, or handle to a context. +/// +/// Can be cheaply cloned with [`Context::shallow_clone`]. These clones will +/// refer to the same underlying type inference context, and can be used as +/// handles to each other. The derived [`Context::clone`] has the same effect. +/// +/// There is currently no way to create an independent context with the same +/// type inference variables (i.e. a deep clone). If you need this functionality, +/// please file an issue. +#[derive(Clone, Default)] +pub struct Context { + slab: Arc>>, +} + +impl fmt::Debug for Context { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let id = Arc::as_ptr(&self.slab) as usize; + write!(f, "inference_ctx_{:08x}", id) + } +} + +impl PartialEq for Context { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.slab, &other.slab) + } +} +impl Eq for Context {} + +impl Context { + /// Creates a new empty type inference context. + pub fn new() -> Self { + Context { + slab: Arc::new(Mutex::new(vec![])), + } + } + + /// Creates a new handle to the context. + /// + /// This handle holds a reference to the underlying context and will keep + /// it alive. The context will only be dropped once all handles, including + /// the original context object, are dropped. 
+ pub fn shallow_clone(&self) -> Self { + Self { + slab: Arc::clone(&self.slab), + } + } + + /// Checks whether two inference contexts are equal, and returns an error if not. + pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> { + if self == other { + Ok(()) + } else { + Err(super::Error::InferenceContextMismatch) + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 9461008e..4f1f4d88 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -79,11 +79,13 @@ use std::fmt; use std::sync::Arc; pub mod arrow; +mod context; mod final_data; mod precomputed; mod union_bound; mod variable; +pub use context::Context; pub use final_data::{CompleteBound, Final}; /// Error type for simplicity @@ -104,6 +106,9 @@ pub enum Error { }, /// A type is recursive (i.e., occurs within itself), violating the "occurs check" OccursCheck { infinite_bound: Arc }, + /// Attempted to combine two nodes which had different type inference + /// contexts. This is probably a programming error. 
+ InferenceContextMismatch, } impl fmt::Display for Error { @@ -134,6 +139,9 @@ impl fmt::Display for Error { Error::OccursCheck { infinite_bound } => { write!(f, "infinitely-sized type {}", infinite_bound,) } + Error::InferenceContextMismatch => { + f.write_str("attempted to combine two nodes with different type inference contexts") + } } } } @@ -606,8 +614,10 @@ mod tests { #[test] fn inference_failure() { + let ctx = Context::new(); + // unit: A -> 1 - let unit = Arc::>::unit(); // 1 -> 1 + let unit = Arc::>::unit(&ctx); // 1 -> 1 // Force unit to be 1->1 Arc::>::comp(&unit, &unit).unwrap(); @@ -623,7 +633,8 @@ mod tests { #[test] fn memory_leak() { - let iden = Arc::>::iden(); + let ctx = Context::new(); + let iden = Arc::>::iden(&ctx); let drop = Arc::>::drop_(&iden); let case = Arc::>::case(&iden, &drop).unwrap(); From c304a9a7a8feadf5c4e5cd745e1b95a159dd30cf Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 14:27:14 +0000 Subject: [PATCH 07/16] types: make `bind` and `unify` go through Context This is again an API-only change. 
--- src/human_encoding/named_node.rs | 44 +++++++++++++++++++++----------- src/node/construct.rs | 17 +++++++----- src/node/witness.rs | 19 ++++++++------ src/types/arrow.rs | 34 ++++++++++++++++-------- src/types/context.rs | 17 +++++++++++- src/types/mod.rs | 4 +-- 6 files changed, 93 insertions(+), 42 deletions(-) diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index 4e9ee4dd..f5db9acc 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -9,7 +9,7 @@ use crate::node::{ self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness, WitnessData, }; -use crate::node::{Construct, ConstructData, Constructible}; +use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _}; use crate::types; use crate::types::arrow::{Arrow, FinalArrow}; use crate::{encode, Value, WitnessNode}; @@ -299,6 +299,11 @@ impl NamedConstructNode { self.cached_data().internal.arrow() } + /// Accessor for the node's type inference context. + pub fn inference_context(&self) -> &types::Context { + self.cached_data().internal.inference_context() + } + /// Finalizes the types of the underlying [`crate::ConstructNode`]. pub fn finalize_types_main(&self) -> Result>, ErrorSet> { self.finalize_types_inner(true) @@ -390,17 +395,23 @@ impl NamedConstructNode { .map_disconnect(|_| &NoDisconnect) .copy_witness(); + let ctx = data.node.inference_context(); + if !self.for_main { // For non-`main` fragments, treat the ascriptions as normative, and apply them // before finalizing the type. 
let arrow = data.node.arrow(); for ty in data.node.cached_data().user_source_types.as_ref() { - if let Err(e) = arrow.source.unify(ty, "binding source type annotation") { + if let Err(e) = + ctx.unify(&arrow.source, ty, "binding source type annotation") + { self.errors.add(data.node.position(), e); } } for ty in data.node.cached_data().user_target_types.as_ref() { - if let Err(e) = arrow.target.unify(ty, "binding target type annotation") { + if let Err(e) = + ctx.unify(&arrow.target, ty, "binding target type annotation") + { self.errors.add(data.node.position(), e); } } @@ -419,13 +430,15 @@ impl NamedConstructNode { // determined the type. let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source)); for ty in data.node.cached_data().user_source_types.as_ref() { - if let Err(e) = source_ty.unify(ty, "binding source type annotation") { + if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation") + { self.errors.add(data.node.position(), e); } } let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target)); for ty in data.node.cached_data().user_target_types.as_ref() { - if let Err(e) = target_ty.unify(ty, "binding target type annotation") { + if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation") + { self.errors.add(data.node.position(), e); } } @@ -446,22 +459,23 @@ impl NamedConstructNode { }; if for_main { + let ctx = self.inference_context(); let unit_ty = types::Type::unit(); if self.cached_data().user_source_types.is_empty() { - if let Err(e) = self - .arrow() - .source - .unify(&unit_ty, "setting root source to unit") - { + if let Err(e) = ctx.unify( + &self.arrow().source, + &unit_ty, + "setting root source to unit", + ) { finalizer.errors.add(self.position(), e); } } if self.cached_data().user_target_types.is_empty() { - if let Err(e) = self - .arrow() - .target - .unify(&unit_ty, "setting root source to unit") - { + if let Err(e) = ctx.unify( + &self.arrow().target, + &unit_ty, + 
"setting root target to unit", + ) { finalizer.errors.add(self.position(), e); } } diff --git a/src/node/construct.rs b/src/node/construct.rs index 49c96a8c..06a216b3 100644 --- a/src/node/construct.rs +++ b/src/node/construct.rs @@ -52,13 +52,18 @@ impl ConstructNode { /// Sets the source and target type of the node to unit pub fn set_arrow_to_program(&self) -> Result<(), types::Error> { + let ctx = self.data.inference_context(); let unit_ty = types::Type::unit(); - self.arrow() - .source - .unify(&unit_ty, "setting root source to unit")?; - self.arrow() - .target - .unify(&unit_ty, "setting root target to unit")?; + ctx.unify( + &self.arrow().source, + &unit_ty, + "setting root source to unit", + )?; + ctx.unify( + &self.arrow().target, + &unit_ty, + "setting root target to unit", + )?; Ok(()) } diff --git a/src/node/witness.rs b/src/node/witness.rs index b1aea3f3..93b9043a 100644 --- a/src/node/witness.rs +++ b/src/node/witness.rs @@ -205,15 +205,18 @@ impl WitnessNode { // 1. First, prune everything that we can let pruned_self = self.prune_and_retype(); // 2. Then, set the root arrow to 1->1 + let ctx = pruned_self.inference_context(); let unit_ty = types::Type::unit(); - pruned_self - .arrow() - .source - .unify(&unit_ty, "setting root source to unit")?; - pruned_self - .arrow() - .target - .unify(&unit_ty, "setting root source to unit")?; + ctx.unify( + &pruned_self.arrow().source, + &unit_ty, + "setting root source to unit", + )?; + ctx.unify( + &pruned_self.arrow().target, + &unit_ty, + "setting root target to unit", + )?; // 3. Then attempt to convert the whole program to a RedeemNode. 
// Despite all of the above this can still fail due to the diff --git a/src/types/arrow.rs b/src/types/arrow.rs index 69faf17c..ddae7c3e 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -123,18 +123,23 @@ impl Arrow { let target = Type::free(String::new()); if let Some(lchild_arrow) = lchild_arrow { - lchild_arrow.source.bind( + inference_context.bind( + &lchild_arrow.source, Arc::new(Bound::Product(a, c.shallow_clone())), "case combinator: left source = A × C", )?; - target.unify(&lchild_arrow.target, "").unwrap(); + inference_context + .unify(&target, &lchild_arrow.target, "") + .unwrap(); } if let Some(rchild_arrow) = rchild_arrow { - rchild_arrow.source.bind( + inference_context.bind( + &rchild_arrow.source, Arc::new(Bound::Product(b, c)), "case combinator: left source = B × C", )?; - target.unify( + inference_context.unify( + &target, &rchild_arrow.target, "case combinator: left target = right target", )?; @@ -153,6 +158,7 @@ impl Arrow { .inference_context .check_eq(&rchild_arrow.inference_context)?; + let ctx = lchild_arrow.inference_context(); let a = Type::free(new_name("disconnect_a_")); let b = Type::free(new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); @@ -162,11 +168,13 @@ impl Arrow { let prod_b_c = Bound::Product(b.shallow_clone(), c); let prod_b_d = Type::product(b, d); - lchild_arrow.source.bind( + ctx.bind( + &lchild_arrow.source, Arc::new(prod_256_a), "disconnect combinator: left source = 2^256 × A", )?; - lchild_arrow.target.bind( + ctx.bind( + &lchild_arrow.target, Arc::new(prod_b_c), "disconnect combinator: left target = B × C", )?; @@ -248,8 +256,11 @@ impl CoreConstructible for Arrow { fn comp(left: &Self, right: &Self) -> Result { left.inference_context.check_eq(&right.inference_context)?; - left.target - .unify(&right.source, "comp combinator: left target = right source")?; + left.inference_context.unify( + &left.target, + &right.source, + "comp combinator: left target = right source", + )?; Ok(Arrow { 
source: left.source.shallow_clone(), target: right.target.shallow_clone(), @@ -271,8 +282,11 @@ impl CoreConstructible for Arrow { fn pair(left: &Self, right: &Self) -> Result { left.inference_context.check_eq(&right.inference_context)?; - left.source - .unify(&right.source, "pair combinator: left source = right source")?; + left.inference_context.unify( + &left.source, + &right.source, + "pair combinator: left source = right source", + )?; Ok(Arrow { source: left.source.shallow_clone(), target: Type::product(left.target.shallow_clone(), right.target.shallow_clone()), diff --git a/src/types/context.rs b/src/types/context.rs index c8585f93..ed9ea81d 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -17,7 +17,7 @@ use std::fmt; use std::sync::{Arc, Mutex}; -use super::Bound; +use super::{Bound, Error, Type}; /// Type inference context, or handle to a context. /// @@ -74,4 +74,19 @@ impl Context { Err(super::Error::InferenceContextMismatch) } } + + /// Binds the type to a given bound. If this fails, attach the provided + /// hint to the error. + /// + /// Fails if the type has an existing incompatible bound. + pub fn bind(&self, existing: &Type, new: Arc, hint: &'static str) -> Result<(), Error> { + existing.bind(new, hint) + } + + /// Unify the type with another one. + /// + /// Fails if the bounds on the two types are incompatible + pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { + ty1.unify(ty2, hint) + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 4f1f4d88..d6f5f631 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -440,7 +440,7 @@ impl Type { /// hint to the error. /// /// Fails if the type has an existing incompatible bound. - pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { + fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { let root = self.bound.root(); root.bind(bound, hint) } @@ -448,7 +448,7 @@ impl Type { /// Unify the type with another one. 
/// /// Fails if the bounds on the two types are incompatible - pub fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { + fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { self.bound.unify(&other.bound, |x_bound, y_bound| { x_bound.bind(y_bound.get(), hint) }) From 2b400713010c75656595bc9fccc3eac9c69ea108 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 18:50:12 +0000 Subject: [PATCH 08/16] types: add &Context to type constructors This is again an "API only" change. Also fixes a couple code-cleanliness things in the benches/ crate. --- .../benches/elements/data_structures.rs | 6 +- jets-bench/benches/elements/main.rs | 6 +- src/human_encoding/named_node.rs | 8 ++- src/human_encoding/parse/ast.rs | 14 ++-- src/human_encoding/parse/mod.rs | 6 +- src/jet/type_name.rs | 6 +- src/merkle/tmr.rs | 6 +- src/node/construct.rs | 2 +- src/node/witness.rs | 2 +- src/types/arrow.rs | 64 +++++++++---------- src/types/mod.rs | 14 ++-- 11 files changed, 69 insertions(+), 65 deletions(-) diff --git a/jets-bench/benches/elements/data_structures.rs b/jets-bench/benches/elements/data_structures.rs index dd857578..f0c9ffba 100644 --- a/jets-bench/benches/elements/data_structures.rs +++ b/jets-bench/benches/elements/data_structures.rs @@ -4,9 +4,8 @@ use bitcoin::secp256k1; use elements::Txid; use rand::{thread_rng, RngCore}; -pub use simplicity::hashes::sha256; use simplicity::{ - bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value, + bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value, }; use std::sync::Arc; @@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result, Erro assert!(n < 16); assert!(v.len() < (1 << (n + 1))); let mut iter = BitIter::new(v.iter().copied()); - let types = Type::powers_of_two(n); // size n + 1 + let ctx = types::Context::new(); + let types = Type::powers_of_two(&ctx, n); // size n + 1 let mut res = None; 
while n > 0 { let v = if v.len() >= (1 << (n + 1)) { diff --git a/jets-bench/benches/elements/main.rs b/jets-bench/benches/elements/main.rs index efc3f6e4..25761d5c 100644 --- a/jets-bench/benches/elements/main.rs +++ b/jets-bench/benches/elements/main.rs @@ -93,8 +93,8 @@ impl ElementsBenchEnvType { } fn jet_arrow(jet: Elements) -> (Arc, Arc) { - let src_ty = jet.source_ty().to_type().final_data().unwrap(); - let tgt_ty = jet.target_ty().to_type().final_data().unwrap(); + let src_ty = jet.source_ty().to_final(); + let tgt_ty = jet.target_ty().to_final(); (src_ty, tgt_ty) } @@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) { let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng()); let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair); - let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap(); + let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap(); let sig = secp_ctx.sign_schnorr(&msg, &keypair); let xpk_value = Value::u256_from_slice(&xpk.0.serialize()); let sig_value = Value::u512_from_slice(sig.as_ref()); diff --git a/src/human_encoding/named_node.rs b/src/human_encoding/named_node.rs index f5db9acc..8aeccec2 100644 --- a/src/human_encoding/named_node.rs +++ b/src/human_encoding/named_node.rs @@ -428,14 +428,16 @@ impl NamedConstructNode { if self.for_main { // For `main`, only apply type ascriptions *after* inference has completely // determined the type. 
- let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source)); + let source_ty = + types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source)); for ty in data.node.cached_data().user_source_types.as_ref() { if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation") { self.errors.add(data.node.position(), e); } } - let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target)); + let target_ty = + types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target)); for ty in data.node.cached_data().user_target_types.as_ref() { if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation") { @@ -460,7 +462,7 @@ impl NamedConstructNode { if for_main { let ctx = self.inference_context(); - let unit_ty = types::Type::unit(); + let unit_ty = types::Type::unit(ctx); if self.cached_data().user_source_types.is_empty() { if let Err(e) = ctx.unify( &self.arrow().source, diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index f58e9d4e..6cf5483d 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -82,14 +82,14 @@ pub enum Type { impl Type { /// Convert to a Simplicity type - pub fn reify(self) -> types::Type { + pub fn reify(self, ctx: &types::Context) -> types::Type { match self { - Type::Name(s) => types::Type::free(s), - Type::One => types::Type::unit(), - Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()), - Type::Product(left, right) => types::Type::product(left.reify(), right.reify()), - Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()), - Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers + Type::Name(s) => types::Type::free(ctx, s), + Type::One => types::Type::unit(ctx), + Type::Two => types::Type::sum(types::Type::unit(ctx), types::Type::unit(ctx)), + Type::Product(left, right) => types::Type::product(left.reify(ctx), right.reify(ctx)), + 
Type::Sum(left, right) => types::Type::sum(left.reify(ctx), right.reify(ctx)), + Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers } } } diff --git a/src/human_encoding/parse/mod.rs b/src/human_encoding/parse/mod.rs index 61919d9f..21a4e7d5 100644 --- a/src/human_encoding/parse/mod.rs +++ b/src/human_encoding/parse/mod.rs @@ -181,6 +181,7 @@ pub fn parse( program: &str, ) -> Result, Arc>>, ErrorSet> { let mut errors = ErrorSet::new(); + let inference_context = types::Context::new(); // ** // Step 1: Read expressions into HashMap, checking for dupes and illegal names. // ** @@ -205,10 +206,10 @@ pub fn parse( } } if let Some(ty) = line.arrow.0 { - entry.add_source_type(ty.reify()); + entry.add_source_type(ty.reify(&inference_context)); } if let Some(ty) = line.arrow.1 { - entry.add_target_type(ty.reify()); + entry.add_target_type(ty.reify(&inference_context)); } } @@ -419,7 +420,6 @@ pub fn parse( drop(unresolved_map); // ** Step 3: convert each DAG of names/expressions into a DAG of NamedNodes. - let inference_context = types::Context::new(); let mut roots = HashMap::, Arc>>::new(); for (name, expr) in &resolved_map { if expr.in_degree.load(Ordering::SeqCst) > 0 { diff --git a/src/jet/type_name.rs b/src/jet/type_name.rs index de6027e3..092f8222 100644 --- a/src/jet/type_name.rs +++ b/src/jet/type_name.rs @@ -4,7 +4,7 @@ //! //! Source and target types of jet nodes need to be specified manually. -use crate::types::{Final, Type}; +use crate::types::{self, Final, Type}; use std::cmp; use std::sync::Arc; @@ -32,8 +32,8 @@ pub struct TypeName(pub &'static [u8]); impl TypeName { /// Convert the type name into a type. - pub fn to_type(&self) -> Type { - Type::complete(self.to_final()) + pub fn to_type(&self, ctx: &types::Context) -> Type { + Type::complete(ctx, self.to_final()) } /// Convert the type name into a finalized type. 
diff --git a/src/merkle/tmr.rs b/src/merkle/tmr.rs index f36d6268..1c468f69 100644 --- a/src/merkle/tmr.rs +++ b/src/merkle/tmr.rs @@ -257,9 +257,11 @@ impl Tmr { #[cfg(test)] mod tests { - use super::super::bip340_iv; use super::*; + use crate::merkle::bip340_iv; + use crate::types; + #[test] fn const_ivs() { assert_eq!( @@ -280,7 +282,7 @@ mod tests { #[allow(clippy::needless_range_loop)] fn const_powers_of_2() { let n = Tmr::POWERS_OF_TWO.len(); - let types = crate::types::Type::powers_of_two(n); + let types = crate::types::Type::powers_of_two(&types::Context::new(), n); for i in 0..n { assert_eq!(Some(Tmr::POWERS_OF_TWO[i]), types[i].tmr()); } diff --git a/src/node/construct.rs b/src/node/construct.rs index 06a216b3..2b57de20 100644 --- a/src/node/construct.rs +++ b/src/node/construct.rs @@ -53,7 +53,7 @@ impl ConstructNode { /// Sets the source and target type of the node to unit pub fn set_arrow_to_program(&self) -> Result<(), types::Error> { let ctx = self.data.inference_context(); - let unit_ty = types::Type::unit(); + let unit_ty = types::Type::unit(ctx); ctx.unify( &self.arrow().source, &unit_ty, diff --git a/src/node/witness.rs b/src/node/witness.rs index 93b9043a..b9748c58 100644 --- a/src/node/witness.rs +++ b/src/node/witness.rs @@ -206,7 +206,7 @@ impl WitnessNode { let pruned_self = self.prune_and_retype(); // 2. 
Then, set the root arrow to 1->1 let ctx = pruned_self.inference_context(); - let unit_ty = types::Type::unit(); + let unit_ty = types::Type::unit(ctx); ctx.unify( &pruned_self.arrow().source, &unit_ty, diff --git a/src/types/arrow.rs b/src/types/arrow.rs index ddae7c3e..40d47678 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -108,37 +108,35 @@ impl Arrow { left.inference_context.check_eq(&right.inference_context)?; } - let inference_context = match (lchild_arrow, rchild_arrow) { + let ctx = match (lchild_arrow, rchild_arrow) { (Some(left), _) => left.inference_context.shallow_clone(), (_, Some(right)) => right.inference_context.shallow_clone(), (None, None) => panic!("called `for_case` with no children"), }; - let a = Type::free(new_name("case_a_")); - let b = Type::free(new_name("case_b_")); - let c = Type::free(new_name("case_c_")); + let a = Type::free(&ctx, new_name("case_a_")); + let b = Type::free(&ctx, new_name("case_b_")); + let c = Type::free(&ctx, new_name("case_c_")); let sum_a_b = Type::sum(a.shallow_clone(), b.shallow_clone()); let prod_sum_a_b_c = Type::product(sum_a_b, c.shallow_clone()); - let target = Type::free(String::new()); + let target = Type::free(&ctx, String::new()); if let Some(lchild_arrow) = lchild_arrow { - inference_context.bind( + ctx.bind( &lchild_arrow.source, Arc::new(Bound::Product(a, c.shallow_clone())), "case combinator: left source = A × C", )?; - inference_context - .unify(&target, &lchild_arrow.target, "") - .unwrap(); + ctx.unify(&target, &lchild_arrow.target, "").unwrap(); } if let Some(rchild_arrow) = rchild_arrow { - inference_context.bind( + ctx.bind( &rchild_arrow.source, Arc::new(Bound::Product(b, c)), "case combinator: left source = B × C", )?; - inference_context.unify( + ctx.unify( &target, &rchild_arrow.target, "case combinator: left target = right target", @@ -148,7 +146,7 @@ impl Arrow { Ok(Arrow { source: prod_sum_a_b_c, target, - inference_context, + inference_context: ctx, }) } @@ -159,12 +157,12 
@@ impl Arrow { .check_eq(&rchild_arrow.inference_context)?; let ctx = lchild_arrow.inference_context(); - let a = Type::free(new_name("disconnect_a_")); - let b = Type::free(new_name("disconnect_b_")); + let a = Type::free(ctx, new_name("disconnect_a_")); + let b = Type::free(ctx, new_name("disconnect_b_")); let c = rchild_arrow.source.shallow_clone(); let d = rchild_arrow.target.shallow_clone(); - let prod_256_a = Bound::Product(Type::two_two_n(8), a.shallow_clone()); + let prod_256_a = Bound::Product(Type::two_two_n(ctx, 8), a.shallow_clone()); let prod_b_c = Bound::Product(b.shallow_clone(), c); let prod_b_d = Type::product(b, d); @@ -194,7 +192,7 @@ impl CoreConstructible for Arrow { // them. This theoretically could lead to more confusing errors for // the user during type inference, but in practice type inference // is completely opaque and there's no harm in making it moreso. - let new = Type::free(new_name("iden_src_")); + let new = Type::free(inference_context, new_name("iden_src_")); Arrow { source: new.shallow_clone(), target: new, @@ -204,8 +202,8 @@ impl CoreConstructible for Arrow { fn unit(inference_context: &Context) -> Self { Arrow { - source: Type::free(new_name("unit_src_")), - target: Type::unit(), + source: Type::free(inference_context, new_name("unit_src_")), + target: Type::unit(inference_context), inference_context: inference_context.shallow_clone(), } } @@ -215,7 +213,7 @@ impl CoreConstructible for Arrow { source: child.source.shallow_clone(), target: Type::sum( child.target.shallow_clone(), - Type::free(new_name("injl_tgt_")), + Type::free(&child.inference_context, new_name("injl_tgt_")), ), inference_context: child.inference_context.shallow_clone(), } @@ -225,7 +223,7 @@ impl CoreConstructible for Arrow { Arrow { source: child.source.shallow_clone(), target: Type::sum( - Type::free(new_name("injr_tgt_")), + Type::free(&child.inference_context, new_name("injr_tgt_")), child.target.shallow_clone(), ), inference_context: 
child.inference_context.shallow_clone(), @@ -236,7 +234,7 @@ impl CoreConstructible for Arrow { Arrow { source: Type::product( child.source.shallow_clone(), - Type::free(new_name("take_src_")), + Type::free(&child.inference_context, new_name("take_src_")), ), target: child.target.shallow_clone(), inference_context: child.inference_context.shallow_clone(), @@ -246,7 +244,7 @@ impl CoreConstructible for Arrow { fn drop_(child: &Self) -> Self { Arrow { source: Type::product( - Type::free(new_name("drop_src_")), + Type::free(&child.inference_context, new_name("drop_src_")), child.source.shallow_clone(), ), target: child.target.shallow_clone(), @@ -296,8 +294,8 @@ impl CoreConstructible for Arrow { fn fail(inference_context: &Context, _: crate::FailEntropy) -> Self { Arrow { - source: Type::free(new_name("fail_src_")), - target: Type::free(new_name("fail_tgt_")), + source: Type::free(inference_context, new_name("fail_src_")), + target: Type::free(inference_context, new_name("fail_tgt_")), inference_context: inference_context.shallow_clone(), } } @@ -308,8 +306,8 @@ impl CoreConstructible for Arrow { assert!(len.is_power_of_two()); let depth = word.len().trailing_zeros(); Arrow { - source: Type::unit(), - target: Type::two_two_n(depth as usize), + source: Type::unit(inference_context), + target: Type::two_two_n(inference_context, depth as usize), inference_context: inference_context.shallow_clone(), } } @@ -327,11 +325,13 @@ impl DisconnectConstructible for Arrow { impl DisconnectConstructible for Arrow { fn disconnect(left: &Self, _: &NoDisconnect) -> Result { + let source = Type::free(&left.inference_context, "disc_src".into()); + let target = Type::free(&left.inference_context, "disc_tgt".into()); Self::for_disconnect( left, &Arrow { - source: Type::free("disc_src".into()), - target: Type::free("disc_tgt".into()), + source, + target, inference_context: left.inference_context.shallow_clone(), }, ) @@ -350,8 +350,8 @@ impl DisconnectConstructible> for Arrow { impl 
JetConstructible for Arrow { fn jet(inference_context: &Context, jet: J) -> Self { Arrow { - source: jet.source_ty().to_type(), - target: jet.target_ty().to_type(), + source: jet.source_ty().to_type(inference_context), + target: jet.target_ty().to_type(inference_context), inference_context: inference_context.shallow_clone(), } } @@ -360,8 +360,8 @@ impl JetConstructible for Arrow { impl WitnessConstructible for Arrow { fn witness(inference_context: &Context, _: W) -> Self { Arrow { - source: Type::free(new_name("witness_src_")), - target: Type::free(new_name("witness_tgt_")), + source: Type::free(inference_context, new_name("witness_src_")), + target: Type::free(inference_context, new_name("witness_tgt_")), inference_context: inference_context.shallow_clone(), } } diff --git a/src/types/mod.rs b/src/types/mod.rs index d6f5f631..57c6ac99 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -397,20 +397,20 @@ pub struct Type { impl Type { /// Return an unbound type with the given name - pub fn free(name: String) -> Self { + pub fn free(_: &Context, name: String) -> Self { Type::from(Bound::free(name)) } /// Create the unit type. - pub fn unit() -> Self { + pub fn unit(_: &Context) -> Self { Type::from(Bound::unit()) } /// Create the type `2^(2^n)` for the given `n`. /// /// The type is precomputed and fast to access. - pub fn two_two_n(n: usize) -> Self { - Self::complete(precomputed::nth_power_of_2(n)) + pub fn two_two_n(ctx: &Context, n: usize) -> Self { + Self::complete(ctx, precomputed::nth_power_of_2(n)) } /// Create the sum of the given `left` and `right` types. @@ -424,7 +424,7 @@ impl Type { } /// Create a complete type. - pub fn complete(final_data: Arc) -> Self { + pub fn complete(_: &Context, final_data: Arc) -> Self { Type::from(Bound::Complete(final_data)) } @@ -561,10 +561,10 @@ impl Type { } /// Return a vector containing the types 2^(2^i) for i from 0 to n-1. 
- pub fn powers_of_two(n: usize) -> Vec { + pub fn powers_of_two(ctx: &Context, n: usize) -> Vec { let mut ret = Vec::with_capacity(n); - let unit = Type::unit(); + let unit = Type::unit(ctx); let mut two = Type::sum(unit.shallow_clone(), unit); for _ in 0..n { ret.push(two.shallow_clone()); From 4fdb46321a81e9bfa9dbca804efdbd006d433858 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 29 Jun 2024 18:52:14 +0000 Subject: [PATCH 09/16] types: add &Context to recursive type constructors The sum and product constructors don't obviously need to be passed a type context -- after all, the two child types already have a type context embedded in them. However, it turns out in practice we have a context available every single time we call these methods, and it's a little bit difficult to pull the context out of Types, since Types only contain a weak pointer to the context object, and in theory the context might have been dropped, leading to a panic. Since the user already has a context handle, just make them pass it in. Then we know it exists. 
--- src/human_encoding/parse/ast.rs | 17 ++++++++++++++--- src/types/arrow.rs | 16 ++++++++++++---- src/types/mod.rs | 8 ++++---- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index 6cf5483d..d34153e2 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -86,9 +86,20 @@ impl Type { match self { Type::Name(s) => types::Type::free(ctx, s), Type::One => types::Type::unit(ctx), - Type::Two => types::Type::sum(types::Type::unit(ctx), types::Type::unit(ctx)), - Type::Product(left, right) => types::Type::product(left.reify(ctx), right.reify(ctx)), - Type::Sum(left, right) => types::Type::sum(left.reify(ctx), right.reify(ctx)), + Type::Two => { + let unit_ty = types::Type::unit(ctx); + types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty) + } + Type::Product(left, right) => { + let left = left.reify(ctx); + let right = right.reify(ctx); + types::Type::product(ctx, left, right) + } + Type::Sum(left, right) => { + let left = left.reify(ctx); + let right = right.reify(ctx); + types::Type::sum(ctx, left, right) + } Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers } } diff --git a/src/types/arrow.rs b/src/types/arrow.rs index 40d47678..54aa40d7 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -118,8 +118,8 @@ impl Arrow { let b = Type::free(&ctx, new_name("case_b_")); let c = Type::free(&ctx, new_name("case_c_")); - let sum_a_b = Type::sum(a.shallow_clone(), b.shallow_clone()); - let prod_sum_a_b_c = Type::product(sum_a_b, c.shallow_clone()); + let sum_a_b = Type::sum(&ctx, a.shallow_clone(), b.shallow_clone()); + let prod_sum_a_b_c = Type::product(&ctx, sum_a_b, c.shallow_clone()); let target = Type::free(&ctx, String::new()); if let Some(lchild_arrow) = lchild_arrow { @@ -164,7 +164,7 @@ impl Arrow { let prod_256_a = Bound::Product(Type::two_two_n(ctx, 8), a.shallow_clone()); let prod_b_c = 
Bound::Product(b.shallow_clone(), c); - let prod_b_d = Type::product(b, d); + let prod_b_d = Type::product(ctx, b, d); ctx.bind( &lchild_arrow.source, @@ -212,6 +212,7 @@ impl CoreConstructible for Arrow { Arrow { source: child.source.shallow_clone(), target: Type::sum( + &child.inference_context, child.target.shallow_clone(), Type::free(&child.inference_context, new_name("injl_tgt_")), ), @@ -223,6 +224,7 @@ impl CoreConstructible for Arrow { Arrow { source: child.source.shallow_clone(), target: Type::sum( + &child.inference_context, Type::free(&child.inference_context, new_name("injr_tgt_")), child.target.shallow_clone(), ), @@ -233,6 +235,7 @@ impl CoreConstructible for Arrow { fn take(child: &Self) -> Self { Arrow { source: Type::product( + &child.inference_context, child.source.shallow_clone(), Type::free(&child.inference_context, new_name("take_src_")), ), @@ -244,6 +247,7 @@ impl CoreConstructible for Arrow { fn drop_(child: &Self) -> Self { Arrow { source: Type::product( + &child.inference_context, Type::free(&child.inference_context, new_name("drop_src_")), child.source.shallow_clone(), ), @@ -287,7 +291,11 @@ impl CoreConstructible for Arrow { )?; Ok(Arrow { source: left.source.shallow_clone(), - target: Type::product(left.target.shallow_clone(), right.target.shallow_clone()), + target: Type::product( + &left.inference_context, + left.target.shallow_clone(), + right.target.shallow_clone(), + ), inference_context: left.inference_context.shallow_clone(), }) } diff --git a/src/types/mod.rs b/src/types/mod.rs index 57c6ac99..b28ed4c1 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -414,12 +414,12 @@ impl Type { } /// Create the sum of the given `left` and `right` types. - pub fn sum(left: Self, right: Self) -> Self { + pub fn sum(_: &Context, left: Self, right: Self) -> Self { Type::from(Bound::sum(left, right)) } /// Create the product of the given `left` and `right` types. 
- pub fn product(left: Self, right: Self) -> Self { + pub fn product(_: &Context, left: Self, right: Self) -> Self { Type::from(Bound::product(left, right)) } @@ -565,10 +565,10 @@ impl Type { let mut ret = Vec::with_capacity(n); let unit = Type::unit(ctx); - let mut two = Type::sum(unit.shallow_clone(), unit); + let mut two = Type::sum(ctx, unit.shallow_clone(), unit); for _ in 0..n { ret.push(two.shallow_clone()); - two = Type::product(two.shallow_clone(), two); + two = Type::product(ctx, two.shallow_clone(), two); } ret } From 5c562a7a6b7815eac4d71ef4f7c791edaa4e3cb8 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 14:13:34 +0000 Subject: [PATCH 10/16] types: abstract pointer type in union-bound algorithm Our union-bound algorithm uses some features of Arc (notably: cheap clones and pointer equality) and therefore holds an Arc. But in a later commit we will want to replace the Arcs with a different reference type, and to do so, we need to make this code more generic. --- src/types/mod.rs | 2 +- src/types/union_bound.rs | 60 ++++++++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/types/mod.rs b/src/types/mod.rs index b28ed4c1..b6a66f28 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -392,7 +392,7 @@ impl DagLike for Arc { pub struct Type { /// A set of constraints, which maintained by the union-bound algorithm and /// is progressively tightened as type inference proceeds. - bound: UbElement, + bound: UbElement>, } impl Type { diff --git a/src/types/union_bound.rs b/src/types/union_bound.rs index 2a984679..f47438d9 100644 --- a/src/types/union_bound.rs +++ b/src/types/union_bound.rs @@ -32,6 +32,30 @@ use std::sync::{Arc, Mutex}; use std::{cmp, fmt, mem}; +/// Trait describing objects that can be stored and manipulated by the union-bound +/// algorithm. +/// +/// Because the algorithm depends on identity equality (i.e. 
two objects being +/// exactly the same in memory) such objects need to have such a notion of +/// equality. In general this differs from the `Eq` trait which implements +/// "semantic equality". +pub trait PointerLike { + /// Whether two objects are the same. + fn ptr_eq(&self, other: &Self) -> bool; + + /// A "shallow copy" of the object. + fn shallow_clone(&self) -> Self; +} + +impl PointerLike for Arc { + fn ptr_eq(&self, other: &Self) -> bool { + Arc::ptr_eq(self, other) + } + fn shallow_clone(&self) -> Self { + Arc::clone(self) + } +} + pub struct UbElement { inner: Arc>>, } @@ -44,7 +68,7 @@ impl Clone for UbElement { } } -impl fmt::Debug for UbElement { +impl fmt::Debug for UbElement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&self.root(), f) } @@ -56,29 +80,31 @@ struct UbInner { } enum UbData { - Root(Arc), + Root(T), EqualTo(UbElement), } impl UbData { - fn shallow_clone(&self) -> Self { + fn unwrap_root(&self) -> &T { match self { - UbData::Root(x) => UbData::Root(Arc::clone(x)), - UbData::EqualTo(eq) => UbData::EqualTo(eq.shallow_clone()), + UbData::Root(ref x) => x, + UbData::EqualTo(..) => unreachable!(), } } +} - fn unwrap_root(&self) -> &Arc { +impl UbData { + fn shallow_clone(&self) -> Self { match self { - UbData::Root(ref x) => x, - UbData::EqualTo(..) => unreachable!(), + UbData::Root(x) => UbData::Root(x.shallow_clone()), + UbData::EqualTo(eq) => UbData::EqualTo(eq.shallow_clone()), } } } impl UbElement { /// Turns an existing piece of data into a singleton union-bound set. - pub fn new(data: Arc) -> Self { + pub fn new(data: T) -> Self { UbElement { inner: Arc::new(Mutex::new(UbInner { data: UbData::Root(data), @@ -92,14 +118,18 @@ impl UbElement { /// This is the same as just calling `.clone()` but has a different name to /// emphasize that what's being cloned is internally just an Arc. 
pub fn shallow_clone(&self) -> Self { - self.clone() + Self { + inner: Arc::clone(&self.inner), + } } +} +impl UbElement { /// Find the representative of this object in its disjoint set. - pub fn root(&self) -> Arc { + pub fn root(&self) -> T { let root = self.root_element(); let inner_lock = root.inner.lock().unwrap(); - Arc::clone(inner_lock.data.unwrap_root()) + inner_lock.data.unwrap_root().shallow_clone() } /// Find the representative of this object in its disjoint set. @@ -145,7 +175,7 @@ impl UbElement { /// to actually be equal. This is accomplished with the `bind_fn` function, /// which takes two arguments: the **new representative that will be kept** /// followed by the **old representative that will be dropped**. - pub fn unify, &Arc) -> Result<(), E>>( + pub fn unify Result<(), E>>( &self, other: &Self, bind_fn: Bind, @@ -167,7 +197,7 @@ impl UbElement { // If our two variables are not literally the same, but through // unification have become the same, we detect _this_ and exit early. - if Arc::ptr_eq(x_lock.data.unwrap_root(), y_lock.data.unwrap_root()) { + if x_lock.data.unwrap_root().ptr_eq(y_lock.data.unwrap_root()) { return Ok(()); } @@ -197,7 +227,7 @@ impl UbElement { } let x_data = match x_lock.data { - UbData::Root(ref arc) => Arc::clone(arc), + UbData::Root(ref data) => data.shallow_clone(), UbData::EqualTo(..) => unreachable!(), }; drop(x_lock); From c37cbcd629371fa1eb2c53668106c29c2925f1c1 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 14:27:41 +0000 Subject: [PATCH 11/16] types: introduce BoundRef type, use in place of Arc in union-bound This introduces the BoundRef type, an opaque reference type which currently just holds an Arc, but which in a later commit will be tied to the type inference context and refer to data within the context. 
It currently preserves the get() and set() methods on BoundRef, which come from BoundMutex, but these will need to be replaced in a later commit, since eventually BoundRef by itself will not contain enough information to update its data. This is essentially an API-only change though it does move some data structures around. --- src/types/context.rs | 97 +++++++++++++++++++++++++++++++++++++++++++- src/types/mod.rs | 53 +++++++++++------------- 2 files changed, 119 insertions(+), 31 deletions(-) diff --git a/src/types/context.rs b/src/types/context.rs index ed9ea81d..92c4a107 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -17,7 +17,8 @@ use std::fmt; use std::sync::{Arc, Mutex}; -use super::{Bound, Error, Type}; +use super::bound_mutex::BoundMutex; +use super::{Bound, Error, Final, Type}; /// Type inference context, or handle to a context. /// @@ -55,6 +56,53 @@ impl Context { } } + /// Helper function to allocate a bound and return a reference to it. + fn alloc_bound(&self, bound: Bound) -> BoundRef { + BoundRef { + context: Arc::as_ptr(&self.slab), + index: Arc::new(BoundMutex::new(bound)), + } + } + + /// Allocate a new free type bound, and return a reference to it. + pub fn alloc_free(&self, name: String) -> BoundRef { + self.alloc_bound(Bound::Free(name)) + } + + /// Allocate a new unit type bound, and return a reference to it. + pub fn alloc_unit(&self) -> BoundRef { + self.alloc_bound(Bound::Complete(Final::unit())) + } + + /// Allocate a new unit type bound, and return a reference to it. + pub fn alloc_complete(&self, data: Arc) -> BoundRef { + self.alloc_bound(Bound::Complete(data)) + } + + /// Allocate a new sum-type bound, and return a reference to it. + /// + /// # Panics + /// + /// Panics if either of the child types are from a different inference context. 
+ pub fn alloc_sum(&self, left: Type, right: Type) -> BoundRef { + left.bound.root().assert_matches_context(self); + right.bound.root().assert_matches_context(self); + + self.alloc_bound(Bound::sum(left, right)) + } + + /// Allocate a new product-type bound, and return a reference to it. + /// + /// # Panics + /// + /// Panics if either of the child types are from a different inference context. + pub fn alloc_product(&self, left: Type, right: Type) -> BoundRef { + left.bound.root().assert_matches_context(self); + right.bound.root().assert_matches_context(self); + + self.alloc_bound(Bound::product(left, right)) + } + /// Creates a new handle to the context. /// /// This handle holds a reference to the underlying context and will keep @@ -90,3 +138,50 @@ impl Context { ty1.unify(ty2, hint) } } + +#[derive(Debug)] +pub struct BoundRef { + context: *const Mutex>, + // Will become an index into the context in a latter commit, but for + // now we set it to an Arc to preserve semantics. + index: Arc, +} + +impl BoundRef { + pub fn assert_matches_context(&self, ctx: &Context) { + assert_eq!( + self.context, + Arc::as_ptr(&ctx.slab), + "bound was accessed from a type inference context that did not create it", + ); + } + + pub fn get(&self) -> Arc { + self.index.get() + } + + pub fn set(&self, new: Arc) { + self.index.set(new) + } + + pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { + self.index.bind(bound, hint) + } +} + +impl super::PointerLike for BoundRef { + fn ptr_eq(&self, other: &Self) -> bool { + debug_assert_eq!( + self.context, other.context, + "tried to compare two bounds from different inference contexts" + ); + Arc::ptr_eq(&self.index, &other.index) + } + + fn shallow_clone(&self) -> Self { + BoundRef { + context: self.context, + index: Arc::clone(&self.index), + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b6a66f28..9f3ed76b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -70,7 +70,7 @@ //! 
or a sum or product of other types. //! -use self::union_bound::UbElement; +use self::union_bound::{PointerLike, UbElement}; use crate::dag::{Dag, DagLike, NoSharing}; use crate::Tmr; @@ -85,7 +85,7 @@ mod precomputed; mod union_bound; mod variable; -pub use context::Context; +pub use context::{BoundRef, Context}; pub use final_data::{CompleteBound, Final}; /// Error type for simplicity @@ -285,14 +285,6 @@ impl Bound { self.clone() } - fn free(name: String) -> Self { - Bound::Free(name) - } - - fn unit() -> Self { - Bound::Complete(Final::unit()) - } - fn sum(a: Type, b: Type) -> Self { if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) { Bound::Complete(Final::sum(adata, bdata)) @@ -392,18 +384,22 @@ impl DagLike for Arc { pub struct Type { /// A set of constraints, which maintained by the union-bound algorithm and /// is progressively tightened as type inference proceeds. - bound: UbElement>, + bound: UbElement, } impl Type { /// Return an unbound type with the given name - pub fn free(_: &Context, name: String) -> Self { - Type::from(Bound::free(name)) + pub fn free(ctx: &Context, name: String) -> Self { + Type { + bound: UbElement::new(ctx.alloc_free(name)), + } } /// Create the unit type. - pub fn unit(_: &Context) -> Self { - Type::from(Bound::unit()) + pub fn unit(ctx: &Context) -> Self { + Type { + bound: UbElement::new(ctx.alloc_unit()), + } } /// Create the type `2^(2^n)` for the given `n`. @@ -414,18 +410,24 @@ impl Type { } /// Create the sum of the given `left` and `right` types. - pub fn sum(_: &Context, left: Self, right: Self) -> Self { - Type::from(Bound::sum(left, right)) + pub fn sum(ctx: &Context, left: Self, right: Self) -> Self { + Type { + bound: UbElement::new(ctx.alloc_sum(left, right)), + } } /// Create the product of the given `left` and `right` types. 
- pub fn product(_: &Context, left: Self, right: Self) -> Self { - Type::from(Bound::product(left, right)) + pub fn product(ctx: &Context, left: Self, right: Self) -> Self { + Type { + bound: UbElement::new(ctx.alloc_product(left, right)), + } } /// Create a complete type. - pub fn complete(_: &Context, final_data: Arc) -> Self { - Type::from(Bound::Complete(final_data)) + pub fn complete(ctx: &Context, final_data: Arc) -> Self { + Type { + bound: UbElement::new(ctx.alloc_complete(final_data)), + } } /// Clones the `Type`. @@ -580,15 +582,6 @@ impl fmt::Display for Type { } } -impl From for Type { - /// Promotes a `Bound` to a type defined by that constraint - fn from(bound: Bound) -> Type { - Type { - bound: UbElement::new(Arc::new(bound_mutex::BoundMutex::new(bound))), - } - } -} - impl DagLike for Type { type Node = Type; fn data(&self) -> &Type { From f3ec9d2a02769f3950aad11377137407c785b481 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 23:28:00 +0000 Subject: [PATCH 12/16] types: remove set and get methods from BoundRef Once our BoundRefs start requiring an inference context to access their data, we won't be able to call .set() and .get() on them individually. Remove these methods, and instead add them on Context. Doing this means that everywhere we currently call .get and .set, we need a context available. To achieve this, we add the context to Type, and swap the fmt::Debug/fmt::Display impls for Type and Bound so that Type is the primary one (since it has the context object available). The change to use BoundRef in the finalization code means we can change our occurs-check from directly using Arc::::as_ptr to using a more "principled" OccursCheckId object yielded from the BoundRef. This in turn means that we no longer need to use Arc anywhere, and can instead directly use Bound (which is cheap to clone and doesn't need to be wrapped in an Arc, except when we are using Arc to obtain a pointer-id for use in the occurs check). 
Converting Arc to Bound in turn lets us remove a bunch of Arc::new and Arc::clone calls throughout. Again, "API only", but there's a lot going on here. --- src/node/redeem.rs | 3 +- src/types/arrow.rs | 8 +- src/types/context.rs | 74 +++++++++++-- src/types/mod.rs | 247 +++++++++++++++++++++++-------------------- 4 files changed, 202 insertions(+), 130 deletions(-) diff --git a/src/node/redeem.rs b/src/node/redeem.rs index f621de29..11e5eb49 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -290,7 +290,8 @@ impl RedeemNode { data: &PostOrderIterItem<&ConstructNode>, _: &NoWitness, ) -> Result, Self::Error> { - let target_ty = data.node.data.arrow().target.finalize()?; + let arrow = data.node.data.arrow(); + let target_ty = arrow.target.finalize()?; self.bits.read_value(&target_ty).map_err(Error::from) } diff --git a/src/types/arrow.rs b/src/types/arrow.rs index 54aa40d7..70e886b8 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -125,7 +125,7 @@ impl Arrow { if let Some(lchild_arrow) = lchild_arrow { ctx.bind( &lchild_arrow.source, - Arc::new(Bound::Product(a, c.shallow_clone())), + Bound::Product(a, c.shallow_clone()), "case combinator: left source = A × C", )?; ctx.unify(&target, &lchild_arrow.target, "").unwrap(); @@ -133,7 +133,7 @@ impl Arrow { if let Some(rchild_arrow) = rchild_arrow { ctx.bind( &rchild_arrow.source, - Arc::new(Bound::Product(b, c)), + Bound::Product(b, c), "case combinator: left source = B × C", )?; ctx.unify( @@ -168,12 +168,12 @@ impl Arrow { ctx.bind( &lchild_arrow.source, - Arc::new(prod_256_a), + prod_256_a, "disconnect combinator: left source = 2^256 × A", )?; ctx.bind( &lchild_arrow.target, - Arc::new(prod_b_c), + prod_b_c, "disconnect combinator: left target = B × C", )?; diff --git a/src/types/context.rs b/src/types/context.rs index 92c4a107..2ca23034 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -17,6 +17,8 @@ use std::fmt; use std::sync::{Arc, Mutex}; +use crate::dag::{Dag, DagLike}; + use 
super::bound_mutex::BoundMutex;
 use super::{Bound, Error, Final, Type};
 
@@ -123,11 +125,37 @@ impl Context {
         }
     }
 
+    /// Accesses a bound.
+    ///
+    /// # Panics
+    ///
+    /// Panics if passed a `BoundRef` that was not allocated by this context.
+    pub fn get(&self, bound: &BoundRef) -> Bound {
+        bound.assert_matches_context(self);
+        bound.index.get().shallow_clone()
+    }
+
+    /// Reassigns a bound to a different bound.
+    ///
+    /// # Panics
+    ///
+    /// Panics if called on a complete type. This is a sanity-check to avoid
+    /// replacing already-completed types, which can cause inefficiencies in
+    /// the union-bound algorithm (and if our replacement changes the type,
+    /// this is probably a bug).
+    ///
+    ///
+    /// Also panics if passed a `BoundRef` that was not allocated by this context.
+    pub fn reassign_non_complete(&self, bound: BoundRef, new: Bound) {
+        bound.assert_matches_context(self);
+        bound.index.set(new)
+    }
+
     /// Binds the type to a given bound. If this fails, attach the provided
     /// hint to the error.
     ///
     /// Fails if the type has an existing incompatible bound.
- pub fn bind(&self, existing: &Type, new: Arc, hint: &'static str) -> Result<(), Error> { + pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> { existing.bind(new, hint) } @@ -139,7 +167,7 @@ impl Context { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BoundRef { context: *const Mutex>, // Will become an index into the context in a latter commit, but for @@ -156,16 +184,18 @@ impl BoundRef { ); } - pub fn get(&self) -> Arc { - self.index.get() - } - - pub fn set(&self, new: Arc) { - self.index.set(new) + pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { + self.index.bind(bound, hint) } - pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { - self.index.bind(bound, hint) + /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`] + /// with `PartialEq` and `Eq` implemented in terms of underlying pointer + /// equality. + pub fn occurs_check_id(&self) -> OccursCheckId { + OccursCheckId { + context: self.context, + index: Arc::as_ptr(&self.index), + } } } @@ -185,3 +215,27 @@ impl super::PointerLike for BoundRef { } } } + +impl<'ctx> DagLike for (&'ctx Context, BoundRef) { + type Node = BoundRef; + fn data(&self) -> &BoundRef { + &self.1 + } + + fn as_dag_node(&self) -> Dag { + match self.0.get(&self.1) { + Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, + Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { + Dag::Binary((self.0, ty1.bound.root()), (self.0, ty2.bound.root())) + } + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub struct OccursCheckId { + context: *const Mutex>, + // Will become an index into the context in a latter commit, but for + // now we set it to an Arc to preserve semantics. 
+ index: *const BoundMutex, +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 9f3ed76b..b50a9de5 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -105,7 +105,7 @@ pub enum Error { hint: &'static str, }, /// A type is recursive (i.e., occurs within itself), violating the "occurs check" - OccursCheck { infinite_bound: Arc }, + OccursCheck { infinite_bound: Bound }, /// Attempted to combine two nodes which had different type inference /// contexts. This is probably a programming error. InferenceContextMismatch, @@ -156,7 +156,7 @@ mod bound_mutex { /// Source or target type of a Simplicity expression pub struct BoundMutex { /// The type's status according to the union-bound algorithm. - inner: Mutex>, + inner: Mutex, } impl fmt::Debug for BoundMutex { @@ -168,24 +168,24 @@ mod bound_mutex { impl BoundMutex { pub fn new(bound: Bound) -> Self { BoundMutex { - inner: Mutex::new(Arc::new(bound)), + inner: Mutex::new(bound), } } - pub fn get(&self) -> Arc { - Arc::clone(&self.inner.lock().unwrap()) + pub fn get(&self) -> Bound { + self.inner.lock().unwrap().shallow_clone() } - pub fn set(&self, new: Arc) { + pub fn set(&self, new: Bound) { let mut lock = self.inner.lock().unwrap(); assert!( - !matches!(**lock, Bound::Complete(..)), + !matches!(*lock, Bound::Complete(..)), "tried to modify finalized type", ); *lock = new; } - pub fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { + pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { let existing_bound = self.get(); let bind_error = || Error::Bind { existing_bound: existing_bound.shallow_clone(), @@ -193,7 +193,7 @@ mod bound_mutex { hint, }; - match (existing_bound.as_ref(), bound.as_ref()) { + match (&existing_bound, &bound) { // Binding a free type to anything is a no-op (_, Bound::Free(_)) => Ok(()), // Free types are simply dropped and replaced by the new bound @@ -226,8 +226,8 @@ mod bound_mutex { CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref 
ty2), ) => { - ty1.bind(Arc::new(Bound::Complete(Arc::clone(comp1))), hint)?; - ty2.bind(Arc::new(Bound::Complete(Arc::clone(comp2))), hint) + ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?; + ty2.bind(Bound::Complete(Arc::clone(comp2)), hint) } _ => Err(bind_error()), } @@ -244,11 +244,11 @@ mod bound_mutex { // It also gives the user access to more information about the type, // prior to finalization. if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { - self.set(Arc::new(Bound::Complete(if let Bound::Sum(..) = *bound { + self.set(Bound::Complete(if let Bound::Sum(..) = bound { Final::sum(data1, data2) } else { Final::product(data1, data2) - }))); + })); } Ok(()) } @@ -263,7 +263,7 @@ mod bound_mutex { } /// The state of a [`Type`] based on all constraints currently imposed on it. -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Bound { /// Fully-unconstrained type Free(String), @@ -275,6 +275,25 @@ pub enum Bound { Product(Type, Type), } +impl fmt::Display for Bound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Bound::Free(s) => f.write_str(s), + Bound::Complete(comp) => comp.fmt(f), + Bound::Sum(ty1, ty2) => { + ty1.fmt(f)?; + f.write_str(" + ")?; + ty2.fmt(f) + } + Bound::Product(ty1, ty2) => { + ty1.fmt(f)?; + f.write_str(" × ")?; + ty2.fmt(f) + } + } + } +} + impl Bound { /// Clones the `Bound`. 
/// @@ -302,86 +321,15 @@ impl Bound { } } -const MAX_DISPLAY_DEPTH: usize = 64; - -impl fmt::Debug for Bound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let arc = Arc::new(self.shallow_clone()); - for data in arc.verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { - if data.depth == MAX_DISPLAY_DEPTH { - if data.n_children_yielded == 0 { - f.write_str("...")?; - } - continue; - } - match (&*data.node, data.n_children_yielded) { - (Bound::Free(ref s), _) => f.write_str(s)?, - (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?, - (Bound::Sum(..), 0) | (Bound::Product(..), 0) => f.write_str("(")?, - (Bound::Sum(..), 2) | (Bound::Product(..), 2) => f.write_str(")")?, - (Bound::Sum(..), _) => f.write_str(" + ")?, - (Bound::Product(..), _) => f.write_str(" × ")?, - } - } - Ok(()) - } -} - -impl fmt::Display for Bound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let arc = Arc::new(self.shallow_clone()); - for data in arc.verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { - if data.depth == MAX_DISPLAY_DEPTH { - if data.n_children_yielded == 0 { - f.write_str("...")?; - } - continue; - } - match (&*data.node, data.n_children_yielded) { - (Bound::Free(ref s), _) => f.write_str(s)?, - (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?, - (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { - if data.index > 0 { - f.write_str("(")?; - } - } - (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { - if data.index > 0 { - f.write_str(")")? - } - } - (Bound::Sum(..), _) => f.write_str(" + ")?, - (Bound::Product(..), _) => f.write_str(" × ")?, - } - } - Ok(()) - } -} - -impl DagLike for Arc { - type Node = Bound; - fn data(&self) -> &Bound { - self - } - - fn as_dag_node(&self) -> Dag { - match **self { - Bound::Free(..) | Bound::Complete(..) 
=> Dag::Nullary, - Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { - Dag::Binary(ty1.bound.root().get(), ty2.bound.root().get()) - } - } - } -} - /// Source or target type of a Simplicity expression. /// /// Internally this type is essentially just a refcounted pointer; it is /// therefore quite cheap to clone, but be aware that cloning will not /// actually create a new independent type, just a second pointer to the /// first one. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Type { + ctx: Context, /// A set of constraints, which maintained by the union-bound algorithm and /// is progressively tightened as type inference proceeds. bound: UbElement, @@ -391,6 +339,7 @@ impl Type { /// Return an unbound type with the given name pub fn free(ctx: &Context, name: String) -> Self { Type { + ctx: ctx.shallow_clone(), bound: UbElement::new(ctx.alloc_free(name)), } } @@ -398,6 +347,7 @@ impl Type { /// Create the unit type. pub fn unit(ctx: &Context) -> Self { Type { + ctx: ctx.shallow_clone(), bound: UbElement::new(ctx.alloc_unit()), } } @@ -412,6 +362,7 @@ impl Type { /// Create the sum of the given `left` and `right` types. pub fn sum(ctx: &Context, left: Self, right: Self) -> Self { Type { + ctx: ctx.shallow_clone(), bound: UbElement::new(ctx.alloc_sum(left, right)), } } @@ -419,6 +370,7 @@ impl Type { /// Create the product of the given `left` and `right` types. pub fn product(ctx: &Context, left: Self, right: Self) -> Self { Type { + ctx: ctx.shallow_clone(), bound: UbElement::new(ctx.alloc_product(left, right)), } } @@ -426,6 +378,7 @@ impl Type { /// Create a complete type. pub fn complete(ctx: &Context, final_data: Arc) -> Self { Type { + ctx: ctx.shallow_clone(), bound: UbElement::new(ctx.alloc_complete(final_data)), } } @@ -442,7 +395,7 @@ impl Type { /// hint to the error. /// /// Fails if the type has an existing incompatible bound. 
- fn bind(&self, bound: Arc, hint: &'static str) -> Result<(), Error> { + fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { let root = self.bound.root(); root.bind(bound, hint) } @@ -452,13 +405,13 @@ impl Type { /// Fails if the bounds on the two types are incompatible fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { self.bound.unify(&other.bound, |x_bound, y_bound| { - x_bound.bind(y_bound.get(), hint) + x_bound.bind(self.ctx.get(y_bound), hint) }) } /// Accessor for this type's bound - pub fn bound(&self) -> Arc { - self.bound.root().get() + pub fn bound(&self) -> Bound { + self.ctx.get(&self.bound.root()) } /// Accessor for the TMR of this type, if it is final @@ -468,7 +421,7 @@ impl Type { /// Accessor for the data of this type, if it is complete pub fn final_data(&self) -> Option> { - if let Bound::Complete(ref data) = *self.bound.root().get() { + if let Bound::Complete(ref data) = self.bound() { Some(Arc::clone(data)) } else { None @@ -481,55 +434,57 @@ impl Type { /// complete, since its children may have been unified to a complete type. To /// ensure a type is complete, call [`Type::finalize`]. pub fn is_final(&self) -> bool { - matches!(*self.bound.root().get(), Bound::Complete(..)) + self.final_data().is_some() } /// Attempts to finalize the type. Returns its TMR on success. pub fn finalize(&self) -> Result, Error> { + use context::OccursCheckId; + /// Helper type for the occurs-check. enum OccursCheckStack { - Iterate(Arc), - Complete(*const Bound), + Iterate(BoundRef), + Complete(OccursCheckId), } // Done with sharing tracker. Actual algorithm follows. let root = self.bound.root(); - let bound = root.get(); - if let Bound::Complete(ref data) = *bound { + let bound = self.ctx.get(&root); + if let Bound::Complete(ref data) = bound { return Ok(Arc::clone(data)); } // First, do occurs-check to ensure that we have no infinitely sized types. 
- let mut stack = vec![OccursCheckStack::Iterate(Arc::clone(&bound))]; + let mut stack = vec![OccursCheckStack::Iterate(root)]; let mut in_progress = HashSet::new(); let mut completed = HashSet::new(); while let Some(top) = stack.pop() { let bound = match top { - OccursCheckStack::Complete(ptr) => { - in_progress.remove(&ptr); - completed.insert(ptr); + OccursCheckStack::Complete(id) => { + in_progress.remove(&id); + completed.insert(id); continue; } OccursCheckStack::Iterate(b) => b, }; - let ptr = bound.as_ref() as *const _; - if completed.contains(&ptr) { + let id = bound.occurs_check_id(); + if completed.contains(&id) { // Once we have iterated through a type, we don't need to check it again. // Without this shortcut the occurs-check would take exponential time. continue; } - if !in_progress.insert(ptr) { + if !in_progress.insert(id) { return Err(Error::OccursCheck { - infinite_bound: bound, + infinite_bound: self.ctx.get(&bound), }); } - stack.push(OccursCheckStack::Complete(ptr)); - if let Some(child) = bound.right_child() { + stack.push(OccursCheckStack::Complete(id)); + if let Some((_, child)) = (&self.ctx, bound.shallow_clone()).right_child() { stack.push(OccursCheckStack::Iterate(child)); } - if let Some(child) = bound.left_child() { + if let Some((_, child)) = (&self.ctx, bound).left_child() { stack.push(OccursCheckStack::Iterate(child)); } } @@ -539,8 +494,8 @@ impl Type { let mut finalized = vec![]; for data in self.shallow_clone().post_order_iter::() { let bound = data.node.bound.root(); - let bound_get = bound.get(); - let final_data = match *bound_get { + let bound_get = self.ctx.get(&bound); + let final_data = match bound_get { Bound::Free(_) => Final::unit(), Bound::Complete(ref arc) => Arc::clone(arc), Bound::Sum(..) 
=> Final::sum( @@ -553,9 +508,9 @@ impl Type { ), }; - if !matches!(*bound_get, Bound::Complete(..)) { - // set() ok because we are if-guarded on this variable not being complete - bound.set(Arc::new(Bound::Complete(Arc::clone(&final_data)))); + if !matches!(bound_get, Bound::Complete(..)) { + self.ctx + .reassign_non_complete(bound, Bound::Complete(Arc::clone(&final_data))); } finalized.push(final_data); } @@ -576,9 +531,71 @@ impl Type { } } +const MAX_DISPLAY_DEPTH: usize = 64; + +impl fmt::Debug for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for data in (&self.ctx, self.bound.root()) + .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) + { + if data.depth == MAX_DISPLAY_DEPTH { + if data.n_children_yielded == 0 { + f.write_str("...")?; + } + continue; + } + let bound = data.node.0.get(&data.node.1); + match (bound, data.n_children_yielded) { + (Bound::Free(ref s), _) => f.write_str(s)?, + (Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?, + (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { + if data.index > 0 { + f.write_str("(")?; + } + } + (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { + if data.index > 0 { + f.write_str(")")? 
+ } + } + (Bound::Sum(..), _) => f.write_str(" + ")?, + (Bound::Product(..), _) => f.write_str(" × ")?, + } + } + Ok(()) + } +} + impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.bound.root().get(), f) + for data in (&self.ctx, self.bound.root()) + .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) + { + if data.depth == MAX_DISPLAY_DEPTH { + if data.n_children_yielded == 0 { + f.write_str("...")?; + } + continue; + } + let bound = data.node.0.get(&data.node.1); + match (bound, data.n_children_yielded) { + (Bound::Free(ref s), _) => f.write_str(s)?, + (Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?, + (Bound::Sum(..), 0) | (Bound::Product(..), 0) => { + if data.index > 0 { + f.write_str("(")?; + } + } + (Bound::Sum(..), 2) | (Bound::Product(..), 2) => { + if data.index > 0 { + f.write_str(")")? + } + } + (Bound::Sum(..), _) => f.write_str(" + ")?, + (Bound::Product(..), _) => f.write_str(" × ")?, + } + } + Ok(()) } } @@ -589,7 +606,7 @@ impl DagLike for Type { } fn as_dag_node(&self) -> Dag { - match *self.bound.root().get() { + match self.bound() { Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { Dag::Binary(ty1.shallow_clone(), ty2.shallow_clone()) From b6a3193c3e472b2fced3348f21ec41d476bf3152 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 16:11:01 +0000 Subject: [PATCH 13/16] types: pull unify and bind into inference context Pulls the unify and bind methods out of Type and BoundMutex and implement them on a new private LockedContext struct. This locks the entire inference context for the duration of a bind or unify operation, and because it only locks inside of non-recursive methods, it is impossible to deadlock. 
This is "API-only" in the sense that the actual type bounds continue to
be represented by free-floating Arcs, but it has a semantic change in
that binds and unifications now happen atomically (due to the
continuously held lock on the context) which fixes a likely class of
bugs wherein if you try to unify related variables from multiple threads
at once, the old code probably would do weird things, due to the very
local locking and total lack of other synchronization.

The next commit will finally delete BoundMutex, move the bounds into the
actual context object, and you will see the point of all these massive
code lifts :).
---
 src/types/context.rs | 116 ++++++++++++++++++++++++++++++++++++++++---
 src/types/mod.rs     |  97 +-----------------------------------
 2 files changed, 110 insertions(+), 103 deletions(-)

diff --git a/src/types/context.rs b/src/types/context.rs
index 2ca23034..981ca355 100644
--- a/src/types/context.rs
+++ b/src/types/context.rs
@@ -15,12 +15,12 @@
 //!
 
 use std::fmt;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, MutexGuard};
 
 use crate::dag::{Dag, DagLike};
 
 use super::bound_mutex::BoundMutex;
-use super::{Bound, Error, Final, Type};
+use super::{Bound, CompleteBound, Error, Final, Type};
 
 /// Type inference context, or handle to a context.
 ///
@@ -156,14 +156,24 @@ impl Context {
     ///
     /// Fails if the type has an existing incompatible bound.
     pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> {
-        existing.bind(new, hint)
+        let existing_root = existing.bound.root();
+        let lock = self.lock();
+        lock.bind(existing_root, new, hint)
     }
 
     /// Unify the type with another one.
     ///
    /// Fails if the bounds on the two types are incompatible
     pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> {
-        ty1.unify(ty2, hint)
+        let lock = self.lock();
+        lock.unify(ty1, ty2, hint)
+    }
+
+    /// Locks the underlying slab mutex.
+ fn lock(&self) -> LockedContext { + LockedContext { + slab: self.slab.lock().unwrap(), + } } } @@ -184,10 +194,6 @@ impl BoundRef { ); } - pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - self.index.bind(bound, hint) - } - /// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`] /// with `PartialEq` and `Eq` implemented in terms of underlying pointer /// equality. @@ -239,3 +245,97 @@ pub struct OccursCheckId { // now we set it to an Arc to preserve semantics. index: *const BoundMutex, } + +/// Structure representing an inference context with its slab allocator mutex locked. +/// +/// This type is never exposed outside of this module and should only exist +/// ephemerally within function calls into this module. +struct LockedContext<'ctx> { + slab: MutexGuard<'ctx, Vec>, +} + +impl<'ctx> LockedContext<'ctx> { + /// Unify the type with another one. + /// + /// Fails if the bounds on the two types are incompatible + fn unify(&self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> { + existing.bound.unify(&other.bound, |x_bound, y_bound| { + self.bind(x_bound, y_bound.index.get(), hint) + }) + } + + fn bind(&self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> { + let existing_bound = existing.index.get(); + let bind_error = || Error::Bind { + existing_bound: existing_bound.shallow_clone(), + new_bound: new.shallow_clone(), + hint, + }; + + match (&existing_bound, &new) { + // Binding a free type to anything is a no-op + (_, Bound::Free(_)) => Ok(()), + // Free types are simply dropped and replaced by the new bound + (Bound::Free(_), _) => { + // Free means non-finalized, so set() is ok. 
+ existing.index.set(new); + Ok(()) + } + // Binding complete->complete shouldn't ever happen, but if so, we just + // compare the two types and return a pass/fail + (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { + if existing_final == new_final { + Ok(()) + } else { + Err(bind_error()) + } + } + // Binding an incomplete to a complete type requires recursion. + (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => { + match (complete.bound(), incomplete) { + // A unit might match a Bound::Free(..) or a Bound::Complete(..), + // and both cases were handled above. So this is an error. + (CompleteBound::Unit, _) => Err(bind_error()), + ( + CompleteBound::Product(ref comp1, ref comp2), + Bound::Product(ref ty1, ref ty2), + ) + | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => { + let bound1 = ty1.bound.root(); + let bound2 = ty2.bound.root(); + self.bind(bound1, Bound::Complete(Arc::clone(comp1)), hint)?; + self.bind(bound2, Bound::Complete(Arc::clone(comp2)), hint) + } + _ => Err(bind_error()), + } + } + (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) + | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { + self.unify(x1, y1, hint)?; + self.unify(x2, y2, hint)?; + // This type was not complete, but it may be after unification, giving us + // an opportunity to finaliize it. We do this eagerly to make sure that + // "complete" (no free children) is always equivalent to "finalized" (the + // bound field having variant Bound::Complete(..)), even during inference. + // + // It also gives the user access to more information about the type, + // prior to finalization. + if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { + existing + .index + .set(Bound::Complete(if let Bound::Sum(..) 
= existing_bound { + Final::sum(data1, data2) + } else { + Final::product(data1, data2) + })); + } + Ok(()) + } + (x, y) => Err(Error::Bind { + existing_bound: x.shallow_clone(), + new_bound: y.shallow_clone(), + hint, + }), + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b50a9de5..fd5c650d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -149,9 +149,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} mod bound_mutex { - use super::{Bound, CompleteBound, Error, Final}; + use super::Bound; use std::fmt; - use std::sync::{Arc, Mutex}; + use std::sync::Mutex; /// Source or target type of a Simplicity expression pub struct BoundMutex { @@ -184,81 +184,6 @@ mod bound_mutex { ); *lock = new; } - - pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - let existing_bound = self.get(); - let bind_error = || Error::Bind { - existing_bound: existing_bound.shallow_clone(), - new_bound: bound.shallow_clone(), - hint, - }; - - match (&existing_bound, &bound) { - // Binding a free type to anything is a no-op - (_, Bound::Free(_)) => Ok(()), - // Free types are simply dropped and replaced by the new bound - (Bound::Free(_), _) => { - // Free means non-finalized, so set() is ok. - self.set(bound); - Ok(()) - } - // Binding complete->complete shouldn't ever happen, but if so, we just - // compare the two types and return a pass/fail - (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => { - if existing_final == new_final { - Ok(()) - } else { - Err(bind_error()) - } - } - // Binding an incomplete to a complete type requires recursion. - (Bound::Complete(complete), incomplete) - | (incomplete, Bound::Complete(complete)) => { - match (complete.bound(), incomplete) { - // A unit might match a Bound::Free(..) or a Bound::Complete(..), - // and both cases were handled above. So this is an error. 
- (CompleteBound::Unit, _) => Err(bind_error()), - ( - CompleteBound::Product(ref comp1, ref comp2), - Bound::Product(ref ty1, ref ty2), - ) - | ( - CompleteBound::Sum(ref comp1, ref comp2), - Bound::Sum(ref ty1, ref ty2), - ) => { - ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?; - ty2.bind(Bound::Complete(Arc::clone(comp2)), hint) - } - _ => Err(bind_error()), - } - } - (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) - | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { - x1.unify(y1, hint)?; - x2.unify(y2, hint)?; - // This type was not complete, but it may be after unification, giving us - // an opportunity to finaliize it. We do this eagerly to make sure that - // "complete" (no free children) is always equivalent to "finalized" (the - // bound field having variant Bound::Complete(..)), even during inference. - // - // It also gives the user access to more information about the type, - // prior to finalization. - if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { - self.set(Bound::Complete(if let Bound::Sum(..) = bound { - Final::sum(data1, data2) - } else { - Final::product(data1, data2) - })); - } - Ok(()) - } - (x, y) => Err(Error::Bind { - existing_bound: x.shallow_clone(), - new_bound: y.shallow_clone(), - hint, - }), - } - } } } @@ -391,24 +316,6 @@ impl Type { self.clone() } - /// Binds the type to a given bound. If this fails, attach the provided - /// hint to the error. - /// - /// Fails if the type has an existing incompatible bound. - fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> { - let root = self.bound.root(); - root.bind(bound, hint) - } - - /// Unify the type with another one. 
- /// - /// Fails if the bounds on the two types are incompatible - fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> { - self.bound.unify(&other.bound, |x_bound, y_bound| { - x_bound.bind(self.ctx.get(y_bound), hint) - }) - } - /// Accessor for this type's bound pub fn bound(&self) -> Bound { self.ctx.get(&self.bound.root()) From b4a267261d4a6ddbd60965d93a238ba72191a751 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 30 Jun 2024 21:52:43 +0000 Subject: [PATCH 14/16] types: drop BoundMutex and instead use references into the type context slab This completes the transition to using type contexts to keep track of (and allocate/mass-deallocate) type bounds :). There are three major improvements in this changeset: * We no longer leak memory when infinite type bounds are constructed. * It is no longer possible to create distinct programs where the variables are mixed up. (Ok, you can do this still, but you have to explicitly use the same type context for both programs, which is an obvious bug.) * Unification and binding happen atomically, so if you are doing type inference across multiple threads, crosstalk won't happen between them. --- src/types/context.rs | 65 ++++++++++++++++++++++++++------------------ src/types/mod.rs | 39 -------------------------- 2 files changed, 39 insertions(+), 65 deletions(-) diff --git a/src/types/context.rs b/src/types/context.rs index 981ca355..04775a7a 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -19,7 +19,6 @@ use std::sync::{Arc, Mutex, MutexGuard}; use crate::dag::{Dag, DagLike}; -use super::bound_mutex::BoundMutex; use super::{Bound, CompleteBound, Error, Final, Type}; /// Type inference context, or handle to a context. @@ -60,9 +59,13 @@ impl Context { /// Helper function to allocate a bound and return a reference to it. 
fn alloc_bound(&self, bound: Bound) -> BoundRef { + let mut lock = self.lock(); + lock.slab.push(bound); + let index = lock.slab.len() - 1; + BoundRef { context: Arc::as_ptr(&self.slab), - index: Arc::new(BoundMutex::new(bound)), + index, } } @@ -132,7 +135,8 @@ impl Context { /// Panics if passed a `BoundRef` that was not allocated by this context. pub fn get(&self, bound: &BoundRef) -> Bound { bound.assert_matches_context(self); - bound.index.get().shallow_clone() + let lock = self.lock(); + lock.slab[bound.index].shallow_clone() } /// Reassigns a bound to a different bound. @@ -147,8 +151,8 @@ impl Context { /// /// Also panics if passed a `BoundRef` that was not allocated by this context. pub fn reassign_non_complete(&self, bound: BoundRef, new: Bound) { - bound.assert_matches_context(self); - bound.index.set(new) + let mut lock = self.lock(); + lock.reassign_non_complete(bound, new); } /// Binds the type to a given bound. If this fails, attach the provided @@ -157,7 +161,7 @@ impl Context { /// Fails if the type has an existing incompatible bound. pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> { let existing_root = existing.bound.root(); - let lock = self.lock(); + let mut lock = self.lock(); lock.bind(existing_root, new, hint) } @@ -165,7 +169,7 @@ impl Context { /// /// Fails if the bounds on the two types are incompatible pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { - let lock = self.lock(); + let mut lock = self.lock(); lock.unify(ty1, ty2, hint) } @@ -180,9 +184,7 @@ impl Context { #[derive(Debug, Clone)] pub struct BoundRef { context: *const Mutex>, - // Will become an index into the context in a latter commit, but for - // now we set it to an Arc to preserve semantics. 
- index: Arc, + index: usize, } impl BoundRef { @@ -200,7 +202,7 @@ impl BoundRef { pub fn occurs_check_id(&self) -> OccursCheckId { OccursCheckId { context: self.context, - index: Arc::as_ptr(&self.index), + index: self.index, } } } @@ -211,13 +213,13 @@ impl super::PointerLike for BoundRef { self.context, other.context, "tried to compare two bounds from different inference contexts" ); - Arc::ptr_eq(&self.index, &other.index) + self.index == other.index } fn shallow_clone(&self) -> Self { BoundRef { context: self.context, - index: Arc::clone(&self.index), + index: self.index, } } } @@ -243,7 +245,7 @@ pub struct OccursCheckId { context: *const Mutex>, // Will become an index into the context in a latter commit, but for // now we set it to an Arc to preserve semantics. - index: *const BoundMutex, + index: usize, } /// Structure representing an inference context with its slab allocator mutex locked. @@ -255,17 +257,25 @@ struct LockedContext<'ctx> { } impl<'ctx> LockedContext<'ctx> { + fn reassign_non_complete(&mut self, bound: BoundRef, new: Bound) { + assert!( + !matches!(self.slab[bound.index], Bound::Complete(..)), + "tried to modify finalized type", + ); + self.slab[bound.index] = new; + } + /// Unify the type with another one. 
/// /// Fails if the bounds on the two types are incompatible - fn unify(&self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> { + fn unify(&mut self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> { existing.bound.unify(&other.bound, |x_bound, y_bound| { - self.bind(x_bound, y_bound.index.get(), hint) + self.bind(x_bound, self.slab[y_bound.index].shallow_clone(), hint) }) } - fn bind(&self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> { - let existing_bound = existing.index.get(); + fn bind(&mut self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> { + let existing_bound = self.slab[existing.index].shallow_clone(); let bind_error = || Error::Bind { existing_bound: existing_bound.shallow_clone(), new_bound: new.shallow_clone(), @@ -278,7 +288,7 @@ impl<'ctx> LockedContext<'ctx> { // Free types are simply dropped and replaced by the new bound (Bound::Free(_), _) => { // Free means non-finalized, so set() is ok. - existing.index.set(new); + self.reassign_non_complete(existing, new); Ok(()) } // Binding complete->complete shouldn't ever happen, but if so, we just @@ -320,14 +330,17 @@ impl<'ctx> LockedContext<'ctx> { // // It also gives the user access to more information about the type, // prior to finalization. - if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) { - existing - .index - .set(Bound::Complete(if let Bound::Sum(..) = existing_bound { - Final::sum(data1, data2) + let y1_bound = &self.slab[y1.bound.root().index]; + let y2_bound = &self.slab[y2.bound.root().index]; + if let (Bound::Complete(data1), Bound::Complete(data2)) = (y1_bound, y2_bound) { + self.reassign_non_complete( + existing, + Bound::Complete(if let Bound::Sum(..) 
= existing_bound { + Final::sum(Arc::clone(data1), Arc::clone(data2)) } else { - Final::product(data1, data2) - })); + Final::product(Arc::clone(data1), Arc::clone(data2)) + }), + ); } Ok(()) } diff --git a/src/types/mod.rs b/src/types/mod.rs index fd5c650d..90ae2b54 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -148,45 +148,6 @@ impl fmt::Display for Error { impl std::error::Error for Error {} -mod bound_mutex { - use super::Bound; - use std::fmt; - use std::sync::Mutex; - - /// Source or target type of a Simplicity expression - pub struct BoundMutex { - /// The type's status according to the union-bound algorithm. - inner: Mutex, - } - - impl fmt::Debug for BoundMutex { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.get().fmt(f) - } - } - - impl BoundMutex { - pub fn new(bound: Bound) -> Self { - BoundMutex { - inner: Mutex::new(bound), - } - } - - pub fn get(&self) -> Bound { - self.inner.lock().unwrap().shallow_clone() - } - - pub fn set(&self, new: Bound) { - let mut lock = self.inner.lock().unwrap(); - assert!( - !matches!(*lock, Bound::Complete(..)), - "tried to modify finalized type", - ); - *lock = new; - } - } -} - /// The state of a [`Type`] based on all constraints currently imposed on it. #[derive(Clone, Debug)] pub enum Bound { From 0fd925030d5d3749032eb61ab52ad80ea3eead2b Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Mon, 1 Jul 2024 00:21:33 +0000 Subject: [PATCH 15/16] types: refactor bind/unify error return a bit This avoids passing `hint` through all the layers of recursion, but more importantly, constructs the actual error from a `Context` rather than from a `LockedContext`. In the next commit, we will need extra information to do this, and not want our context to be locked at the time. 
--- src/types/context.rs | 45 +++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/types/context.rs b/src/types/context.rs index 04775a7a..10b6c44b 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -162,7 +162,11 @@ impl Context { pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> { let existing_root = existing.bound.root(); let mut lock = self.lock(); - lock.bind(existing_root, new, hint) + lock.bind(existing_root, new).map_err(|e| Error::Bind { + existing_bound: e.existing, + new_bound: e.new, + hint, + }) } /// Unify the type with another one. @@ -170,7 +174,11 @@ impl Context { /// Fails if the bounds on the two types are incompatible pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { let mut lock = self.lock(); - lock.unify(ty1, ty2, hint) + lock.unify(ty1, ty2).map_err(|e| Error::Bind { + existing_bound: e.existing, + new_bound: e.new, + hint, + }) } /// Locks the underlying slab mutex. @@ -248,6 +256,11 @@ pub struct OccursCheckId { index: usize, } +struct BindError { + existing: Bound, + new: Bound, +} + /// Structure representing an inference context with its slab allocator mutex locked. /// /// This type is never exposed outside of this module and should only exist @@ -268,18 +281,17 @@ impl<'ctx> LockedContext<'ctx> { /// Unify the type with another one. 
/// /// Fails if the bounds on the two types are incompatible - fn unify(&mut self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> { + fn unify(&mut self, existing: &Type, other: &Type) -> Result<(), BindError> { existing.bound.unify(&other.bound, |x_bound, y_bound| { - self.bind(x_bound, self.slab[y_bound.index].shallow_clone(), hint) + self.bind(x_bound, self.slab[y_bound.index].shallow_clone()) }) } - fn bind(&mut self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> { + fn bind(&mut self, existing: BoundRef, new: Bound) -> Result<(), BindError> { let existing_bound = self.slab[existing.index].shallow_clone(); - let bind_error = || Error::Bind { - existing_bound: existing_bound.shallow_clone(), - new_bound: new.shallow_clone(), - hint, + let bind_error = || BindError { + existing: existing_bound.shallow_clone(), + new: new.shallow_clone(), }; match (&existing_bound, &new) { @@ -313,16 +325,16 @@ impl<'ctx> LockedContext<'ctx> { | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => { let bound1 = ty1.bound.root(); let bound2 = ty2.bound.root(); - self.bind(bound1, Bound::Complete(Arc::clone(comp1)), hint)?; - self.bind(bound2, Bound::Complete(Arc::clone(comp2)), hint) + self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?; + self.bind(bound2, Bound::Complete(Arc::clone(comp2))) } _ => Err(bind_error()), } } (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2)) | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => { - self.unify(x1, y1, hint)?; - self.unify(x2, y2, hint)?; + self.unify(x1, y1)?; + self.unify(x2, y2)?; // This type was not complete, but it may be after unification, giving us // an opportunity to finaliize it. 
We do this eagerly to make sure that // "complete" (no free children) is always equivalent to "finalized" (the @@ -344,10 +356,9 @@ impl<'ctx> LockedContext<'ctx> { } Ok(()) } - (x, y) => Err(Error::Bind { - existing_bound: x.shallow_clone(), - new_bound: y.shallow_clone(), - hint, + (x, y) => Err(BindError { + existing: x.shallow_clone(), + new: y.shallow_clone(), }), } } From 46c8720550f0e8c46c06af1b7ff3350a08b8cfe4 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Mon, 1 Jul 2024 00:28:19 +0000 Subject: [PATCH 16/16] f make Bound private, use Type as non-recursive public type After the previous commit, we still had a memory leak in the case of cyclic types, but for a very dumb reason: the Type structure was carrying a Context Arc, and Type was recursive. Instead, introduce a new private TypeInner type which is recursive and does not have direct access to the Context. Make Bound private as well, since it is mutually recursive with Context. I am not thrilled with this commit and may revisit it, but it fixes the leak and lets me get on with fuzzing before I go to bed. 
--- src/types/arrow.rs | 28 ++++---- src/types/context.rs | 149 +++++++++++++++++++++++++++++++------------ src/types/mod.rs | 128 ++++++++++++------------------------- 3 files changed, 166 insertions(+), 139 deletions(-) diff --git a/src/types/arrow.rs b/src/types/arrow.rs index 70e886b8..8fbee5f6 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -18,7 +18,7 @@ use crate::node::{ CoreConstructible, DisconnectConstructible, JetConstructible, NoDisconnect, WitnessConstructible, }; -use crate::types::{Bound, Context, Error, Final, Type}; +use crate::types::{Context, Error, Final, Type}; use crate::{jet::Jet, Value}; use super::variable::new_name; @@ -123,17 +123,19 @@ impl Arrow { let target = Type::free(&ctx, String::new()); if let Some(lchild_arrow) = lchild_arrow { - ctx.bind( + ctx.bind_product( &lchild_arrow.source, - Bound::Product(a, c.shallow_clone()), + &a, + &c, "case combinator: left source = A × C", )?; ctx.unify(&target, &lchild_arrow.target, "").unwrap(); } if let Some(rchild_arrow) = rchild_arrow { - ctx.bind( + ctx.bind_product( &rchild_arrow.source, - Bound::Product(b, c), + &b, + &c, "case combinator: left source = B × C", )?; ctx.unify( @@ -162,21 +164,21 @@ impl Arrow { let c = rchild_arrow.source.shallow_clone(); let d = rchild_arrow.target.shallow_clone(); - let prod_256_a = Bound::Product(Type::two_two_n(ctx, 8), a.shallow_clone()); - let prod_b_c = Bound::Product(b.shallow_clone(), c); - let prod_b_d = Type::product(ctx, b, d); - - ctx.bind( + ctx.bind_product( &lchild_arrow.source, - prod_256_a, + &Type::two_two_n(ctx, 8), + &a, "disconnect combinator: left source = 2^256 × A", )?; - ctx.bind( + ctx.bind_product( &lchild_arrow.target, - prod_b_c, + &b, + &c, "disconnect combinator: left target = B × C", )?; + let prod_b_d = Type::product(ctx, b, d); + Ok(Arrow { source: a, target: prod_b_d, diff --git a/src/types/context.rs b/src/types/context.rs index 10b6c44b..efa2b898 100644 --- a/src/types/context.rs +++ b/src/types/context.rs 
@@ -19,7 +19,7 @@ use std::sync::{Arc, Mutex, MutexGuard}; use crate::dag::{Dag, DagLike}; -use super::{Bound, CompleteBound, Error, Final, Type}; +use super::{Bound, CompleteBound, Error, Final, Type, TypeInner}; /// Type inference context, or handle to a context. /// @@ -60,13 +60,7 @@ impl Context { /// Helper function to allocate a bound and return a reference to it. fn alloc_bound(&self, bound: Bound) -> BoundRef { let mut lock = self.lock(); - lock.slab.push(bound); - let index = lock.slab.len() - 1; - - BoundRef { - context: Arc::as_ptr(&self.slab), - index, - } + lock.alloc_bound(Arc::as_ptr(&self.slab), bound) } /// Allocate a new free type bound, and return a reference to it. @@ -90,10 +84,24 @@ impl Context { /// /// Panics if either of the child types are from a different inference context. pub fn alloc_sum(&self, left: Type, right: Type) -> BoundRef { - left.bound.root().assert_matches_context(self); - right.bound.root().assert_matches_context(self); + assert_eq!( + left.ctx, *self, + "left type did not match inference context of sum" + ); + assert_eq!( + right.ctx, *self, + "right type did not match inference context of sum" + ); - self.alloc_bound(Bound::sum(left, right)) + let mut lock = self.lock(); + if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Complete(Final::sum(data1, data2)), + ) + } else { + lock.alloc_bound(Arc::as_ptr(&self.slab), Bound::Sum(left.inner, right.inner)) + } } /// Allocate a new product-type bound, and return a reference to it. @@ -102,10 +110,27 @@ impl Context { /// /// Panics if either of the child types are from a different inference context. 
pub fn alloc_product(&self, left: Type, right: Type) -> BoundRef { - left.bound.root().assert_matches_context(self); - right.bound.root().assert_matches_context(self); + assert_eq!( + left.ctx, *self, + "left type did not match inference context of product" + ); + assert_eq!( + right.ctx, *self, + "right type did not match inference context of product" + ); - self.alloc_bound(Bound::product(left, right)) + let mut lock = self.lock(); + if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Complete(Final::product(data1, data2)), + ) + } else { + lock.alloc_bound( + Arc::as_ptr(&self.slab), + Bound::Product(left.inner, right.inner), + ) + } } /// Creates a new handle to the context. @@ -133,7 +158,7 @@ impl Context { /// # Panics /// /// Panics if passed a `BoundRef` that was not allocated by this context. - pub fn get(&self, bound: &BoundRef) -> Bound { + pub(super) fn get(&self, bound: &BoundRef) -> Bound { bound.assert_matches_context(self); let lock = self.lock(); lock.slab[bound.index].shallow_clone() @@ -150,22 +175,37 @@ impl Context { /// probably a bug. /// /// Also panics if passed a `BoundRef` that was not allocated by this context. - pub fn reassign_non_complete(&self, bound: BoundRef, new: Bound) { + pub(super) fn reassign_non_complete(&self, bound: BoundRef, new: Bound) { let mut lock = self.lock(); lock.reassign_non_complete(bound, new); } - /// Binds the type to a given bound. If this fails, attach the provided - /// hint to the error. + /// Binds the type to a product bound formed by the two inner types. If this + /// fails, attach the provided hint to the error. /// /// Fails if the type has an existing incompatible bound. 
- pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> { - let existing_root = existing.bound.root(); + pub fn bind_product( + &self, + existing: &Type, + prod_l: &Type, + prod_r: &Type, + hint: &'static str, + ) -> Result<(), Error> { + assert_eq!(existing.ctx, *self); + assert_eq!(prod_l.ctx, *self); + assert_eq!(prod_r.ctx, *self); + + let existing_root = existing.inner.bound.root(); + let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone()); + let mut lock = self.lock(); - lock.bind(existing_root, new).map_err(|e| Error::Bind { - existing_bound: e.existing, - new_bound: e.new, - hint, + lock.bind(existing_root, new_bound).map_err(|e| { + let new_bound = lock.alloc_bound(Arc::as_ptr(&self.slab), e.new); + Error::Bind { + existing_bound: Type::wrap_bound(self, e.existing), + new_bound: Type::wrap_bound(self, new_bound), + hint, + } }) } @@ -173,11 +213,16 @@ impl Context { /// /// Fails if the bounds on the two types are incompatible pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> { + assert_eq!(ty1.ctx, *self); + assert_eq!(ty2.ctx, *self); let mut lock = self.lock(); - lock.unify(ty1, ty2).map_err(|e| Error::Bind { - existing_bound: e.existing, - new_bound: e.new, - hint, + lock.unify(&ty1.inner, &ty2.inner).map_err(|e| { + let new_bound = lock.alloc_bound(Arc::as_ptr(&self.slab), e.new); + Error::Bind { + existing_bound: Type::wrap_bound(self, e.existing), + new_bound: Type::wrap_bound(self, new_bound), + hint, + } }) } @@ -257,7 +302,7 @@ pub struct OccursCheckId { } struct BindError { - existing: Bound, + existing: BoundRef, new: Bound, } @@ -270,6 +315,16 @@ struct LockedContext<'ctx> { } impl<'ctx> LockedContext<'ctx> { + fn alloc_bound(&mut self, ctx_ptr: *const Mutex>, bound: Bound) -> BoundRef { + self.slab.push(bound); + let index = self.slab.len() - 1; + + BoundRef { + context: ctx_ptr, + index, + } + } + fn reassign_non_complete(&mut self, 
bound: BoundRef, new: Bound) { assert!( !matches!(self.slab[bound.index], Bound::Complete(..)), @@ -278,10 +333,29 @@ impl<'ctx> LockedContext<'ctx> { self.slab[bound.index] = new; } + /// It is a common situation that we are pairing two types, and in the + /// case that they are both complete, we want to pair the complete types. + /// + /// This method deals with all the annoying/complicated member variable + /// paths to get the actual complete data out. + fn complete_pair_data( + &self, + inn1: &TypeInner, + inn2: &TypeInner, + ) -> Option<(Arc, Arc)> { + let bound1 = &self.slab[inn1.bound.root().index]; + let bound2 = &self.slab[inn2.bound.root().index]; + if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) { + Some((Arc::clone(data1), Arc::clone(data2))) + } else { + None + } + } + /// Unify the type with another one. /// /// Fails if the bounds on the two types are incompatible - fn unify(&mut self, existing: &Type, other: &Type) -> Result<(), BindError> { + fn unify(&mut self, existing: &TypeInner, other: &TypeInner) -> Result<(), BindError> { existing.bound.unify(&other.bound, |x_bound, y_bound| { self.bind(x_bound, self.slab[y_bound.index].shallow_clone()) }) @@ -290,7 +364,7 @@ impl<'ctx> LockedContext<'ctx> { fn bind(&mut self, existing: BoundRef, new: Bound) -> Result<(), BindError> { let existing_bound = self.slab[existing.index].shallow_clone(); let bind_error = || BindError { - existing: existing_bound.shallow_clone(), + existing: existing.clone(), new: new.shallow_clone(), }; @@ -342,24 +416,19 @@ impl<'ctx> LockedContext<'ctx> { // // It also gives the user access to more information about the type, // prior to finalization. 
- let y1_bound = &self.slab[y1.bound.root().index]; - let y2_bound = &self.slab[y2.bound.root().index]; - if let (Bound::Complete(data1), Bound::Complete(data2)) = (y1_bound, y2_bound) { + if let Some((data1, data2)) = self.complete_pair_data(y1, y2) { self.reassign_non_complete( existing, Bound::Complete(if let Bound::Sum(..) = existing_bound { - Final::sum(Arc::clone(data1), Arc::clone(data2)) + Final::sum(data1, data2) } else { - Final::product(Arc::clone(data1), Arc::clone(data2)) + Final::product(data1, data2) }), ); } Ok(()) } - (x, y) => Err(BindError { - existing: x.shallow_clone(), - new: y.shallow_clone(), - }), + (_, _) => Err(bind_error()), } } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 90ae2b54..47220853 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -71,7 +71,7 @@ //! use self::union_bound::{PointerLike, UbElement}; -use crate::dag::{Dag, DagLike, NoSharing}; +use crate::dag::{DagLike, NoSharing}; use crate::Tmr; use std::collections::HashSet; @@ -94,8 +94,8 @@ pub use final_data::{CompleteBound, Final}; pub enum Error { /// An attempt to bind a type conflicted with an existing bound on the type Bind { - existing_bound: Bound, - new_bound: Bound, + existing_bound: Type, + new_bound: Type, hint: &'static str, }, /// Two unequal complete types were attempted to be unified @@ -105,7 +105,7 @@ pub enum Error { hint: &'static str, }, /// A type is recursive (i.e., occurs within itself), violating the "occurs check" - OccursCheck { infinite_bound: Bound }, + OccursCheck { infinite_bound: Type }, /// Attempted to combine two nodes which had different type inference /// contexts. This is probably a programming error. InferenceContextMismatch, @@ -149,35 +149,16 @@ impl fmt::Display for Error { impl std::error::Error for Error {} /// The state of a [`Type`] based on all constraints currently imposed on it. 
-#[derive(Clone, Debug)] -pub enum Bound { +#[derive(Clone)] +enum Bound { /// Fully-unconstrained type Free(String), /// Fully-constrained (i.e. complete) type, which has no free variables. Complete(Arc), /// A sum of two other types - Sum(Type, Type), + Sum(TypeInner, TypeInner), /// A product of two other types - Product(Type, Type), -} - -impl fmt::Display for Bound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Bound::Free(s) => f.write_str(s), - Bound::Complete(comp) => comp.fmt(f), - Bound::Sum(ty1, ty2) => { - ty1.fmt(f)?; - f.write_str(" + ")?; - ty2.fmt(f) - } - Bound::Product(ty1, ty2) => { - ty1.fmt(f)?; - f.write_str(" × ")?; - ty2.fmt(f) - } - } - } + Product(TypeInner, TypeInner), } impl Bound { @@ -189,22 +170,6 @@ impl Bound { pub fn shallow_clone(&self) -> Bound { self.clone() } - - fn sum(a: Type, b: Type) -> Self { - if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) { - Bound::Complete(Final::sum(adata, bdata)) - } else { - Bound::Sum(a, b) - } - } - - fn product(a: Type, b: Type) -> Self { - if let (Some(adata), Some(bdata)) = (a.final_data(), b.final_data()) { - Bound::Complete(Final::product(adata, bdata)) - } else { - Bound::Product(a, b) - } - } } /// Source or target type of a Simplicity expression. @@ -215,27 +180,34 @@ impl Bound { /// first one. #[derive(Clone)] pub struct Type { + /// Handle to the type context. ctx: Context, + /// The actual contents of the type. + inner: TypeInner, +} + +#[derive(Clone)] +struct TypeInner { /// A set of constraints, which maintained by the union-bound algorithm and /// is progressively tightened as type inference proceeds. 
bound: UbElement, } +impl TypeInner { + fn shallow_clone(&self) -> Self { + self.clone() + } +} + impl Type { /// Return an unbound type with the given name pub fn free(ctx: &Context, name: String) -> Self { - Type { - ctx: ctx.shallow_clone(), - bound: UbElement::new(ctx.alloc_free(name)), - } + Self::wrap_bound(ctx, ctx.alloc_free(name)) } /// Create the unit type. pub fn unit(ctx: &Context) -> Self { - Type { - ctx: ctx.shallow_clone(), - bound: UbElement::new(ctx.alloc_unit()), - } + Self::wrap_bound(ctx, ctx.alloc_unit()) } /// Create the type `2^(2^n)` for the given `n`. @@ -247,25 +219,26 @@ impl Type { /// Create the sum of the given `left` and `right` types. pub fn sum(ctx: &Context, left: Self, right: Self) -> Self { - Type { - ctx: ctx.shallow_clone(), - bound: UbElement::new(ctx.alloc_sum(left, right)), - } + Self::wrap_bound(ctx, ctx.alloc_sum(left, right)) } /// Create the product of the given `left` and `right` types. pub fn product(ctx: &Context, left: Self, right: Self) -> Self { - Type { - ctx: ctx.shallow_clone(), - bound: UbElement::new(ctx.alloc_product(left, right)), - } + Self::wrap_bound(ctx, ctx.alloc_product(left, right)) } /// Create a complete type. pub fn complete(ctx: &Context, final_data: Arc) -> Self { + Self::wrap_bound(ctx, ctx.alloc_complete(final_data)) + } + + fn wrap_bound(ctx: &Context, bound: BoundRef) -> Self { + bound.assert_matches_context(ctx); Type { ctx: ctx.shallow_clone(), - bound: UbElement::new(ctx.alloc_complete(final_data)), + inner: TypeInner { + bound: UbElement::new(bound), + }, } } @@ -278,8 +251,8 @@ impl Type { } /// Accessor for this type's bound - pub fn bound(&self) -> Bound { - self.ctx.get(&self.bound.root()) + fn bound(&self) -> Bound { + self.ctx.get(&self.inner.bound.root()) } /// Accessor for the TMR of this type, if it is final @@ -316,7 +289,7 @@ impl Type { } // Done with sharing tracker. Actual algorithm follows. 
- let root = self.bound.root(); + let root = self.inner.bound.root(); let bound = self.ctx.get(&root); if let Bound::Complete(ref data) = bound { return Ok(Arc::clone(data)); @@ -344,7 +317,7 @@ impl Type { } if !in_progress.insert(id) { return Err(Error::OccursCheck { - infinite_bound: self.ctx.get(&bound), + infinite_bound: Type::wrap_bound(&self.ctx, bound), }); } @@ -360,9 +333,8 @@ impl Type { // Now that we know our types have finite size, we can safely use a // post-order iterator to finalize them. let mut finalized = vec![]; - for data in self.shallow_clone().post_order_iter::() { - let bound = data.node.bound.root(); - let bound_get = self.ctx.get(&bound); + for data in (&self.ctx, self.inner.bound.root()).post_order_iter::() { + let bound_get = data.node.0.get(&data.node.1); let final_data = match bound_get { Bound::Free(_) => Final::unit(), Bound::Complete(ref arc) => Arc::clone(arc), @@ -378,7 +350,7 @@ impl Type { if !matches!(bound_get, Bound::Complete(..)) { self.ctx - .reassign_non_complete(bound, Bound::Complete(Arc::clone(&final_data))); + .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data))); } finalized.push(final_data); } @@ -403,7 +375,7 @@ const MAX_DISPLAY_DEPTH: usize = 64; impl fmt::Debug for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - for data in (&self.ctx, self.bound.root()) + for data in (&self.ctx, self.inner.bound.root()) .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { if data.depth == MAX_DISPLAY_DEPTH { @@ -436,7 +408,7 @@ impl fmt::Debug for Type { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - for data in (&self.ctx, self.bound.root()) + for data in (&self.ctx, self.inner.bound.root()) .verbose_pre_order_iter::(Some(MAX_DISPLAY_DEPTH)) { if data.depth == MAX_DISPLAY_DEPTH { @@ -467,22 +439,6 @@ impl fmt::Display for Type { } } -impl DagLike for Type { - type Node = Type; - fn data(&self) -> &Type { - self - } - - fn as_dag_node(&self) -> Dag { 
- match self.bound() { - Bound::Free(..) | Bound::Complete(..) => Dag::Nullary, - Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => { - Dag::Binary(ty1.shallow_clone(), ty2.shallow_clone()) - } - } - } -} - #[cfg(test)] mod tests { use super::*;