Skip to content

Commit

Permalink
types: refactor precomputed data to directly store finalize types
Browse files Browse the repository at this point in the history
This eliminates an `unwrap`, cleans up a bit of code in named_node.rs,
and moves every instance of `Type::from` into types/mod.rs, where it
will be easy to modify or remove in a later commit.

Also removes From<Arc<Final>> for Type, a somewhat-weird From impl which
was introduced in #218. It is now superceded by Type::complete.

In general I would like to remove From impls from Type because they have
an inflexible API and because they have somewhat nonobvious behavior.
  • Loading branch information
apoelstra committed Jun 30, 2024
1 parent dcba2f4 commit 12a1885
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 29 deletions.
8 changes: 2 additions & 6 deletions src/human_encoding/named_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,17 +413,13 @@ impl<J: Jet> NamedConstructNode<J> {
if self.for_main {
// For `main`, only apply type ascriptions *after* inference has completely
// determined the type.
let source_bound =
types::Bound::Complete(Arc::clone(&commit_data.arrow().source));
let source_ty = types::Type::from(source_bound);
let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source));
for ty in data.node.cached_data().user_source_types.as_ref() {
if let Err(e) = source_ty.unify(ty, "binding source type annotation") {
self.errors.add(data.node.position(), e);
}
}
let target_bound =
types::Bound::Complete(Arc::clone(&commit_data.arrow().target));
let target_ty = types::Type::from(target_bound);
let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target));
for ty in data.node.cached_data().user_target_types.as_ref() {
if let Err(e) = target_ty.unify(ty, "binding target type annotation") {
self.errors.add(data.node.position(), e);
Expand Down
4 changes: 1 addition & 3 deletions src/human_encoding/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,9 +633,7 @@ fn grammar<J: Jet + 'static>() -> Grammar<Ast<J>> {
Error::BadWordLength { bit_length },
));
}
let ty = types::Type::two_two_n(bit_length.trailing_zeros() as usize)
.final_data()
.unwrap();
let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize);
// unwrap ok here since literally every sequence of bits is a valid
// value for the given type
let value = iter.read_value(&ty).unwrap();
Expand Down
9 changes: 1 addition & 8 deletions src/types/final_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//!
use crate::dag::{Dag, DagLike, NoSharing};
use crate::types::{Bound, Type};
use crate::Tmr;

use std::sync::Arc;
Expand Down Expand Up @@ -163,7 +162,7 @@ impl Final {
///
/// The type is precomputed and fast to access.
pub fn two_two_n(n: usize) -> Arc<Self> {
super::precomputed::nth_power_of_2(n).final_data().unwrap()
super::precomputed::nth_power_of_2(n)
}

/// Create the sum of the given `left` and `right` types.
Expand Down Expand Up @@ -227,12 +226,6 @@ impl Final {
}
}

impl From<Arc<Final>> for Type {
fn from(value: Arc<Final>) -> Self {
Type::from(Bound::Complete(value))
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
7 changes: 6 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl Type {
///
/// The type is precomputed and fast to access.
pub fn two_two_n(n: usize) -> Self {
precomputed::nth_power_of_2(n)
Self::complete(precomputed::nth_power_of_2(n))
}

/// Create the sum of the given `left` and `right` types.
Expand All @@ -415,6 +415,11 @@ impl Type {
Type::from(Bound::product(left, right))
}

/// Create a complete type.
pub fn complete(final_data: Arc<Final>) -> Self {
Type::from(Bound::Complete(final_data))
}

/// Clones the `Type`.
///
/// This is the same as just calling `.clone()` but has a different name to
Expand Down
23 changes: 12 additions & 11 deletions src/types/precomputed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,34 @@
use crate::Tmr;

use super::Type;
use super::Final;

use std::cell::RefCell;
use std::convert::TryInto;
use std::sync::Arc;

// Directly use the size of the precomputed TMR table to make sure they're in sync.
const N_POWERS: usize = Tmr::POWERS_OF_TWO.len();

thread_local! {
static POWERS_OF_TWO: RefCell<Option<[Type; N_POWERS]>> = RefCell::new(None);
static POWERS_OF_TWO: RefCell<Option<[Arc<Final>; N_POWERS]>> = RefCell::new(None);
}

fn initialize(write: &mut Option<[Type; N_POWERS]>) {
let one = Type::unit();
fn initialize(write: &mut Option<[Arc<Final>; N_POWERS]>) {
let one = Final::unit();
let mut powers = Vec::with_capacity(N_POWERS);

// Two^(2^0) = Two = (One + One)
let mut power = Type::sum(one.shallow_clone(), one);
powers.push(power.shallow_clone());
let mut power = Final::sum(Arc::clone(&one), one);
powers.push(Arc::clone(&power));

// Two^(2^(i + 1)) = (Two^(2^i) * Two^(2^i))
for _ in 1..N_POWERS {
power = Type::product(power.shallow_clone(), power);
powers.push(power.shallow_clone());
power = Final::product(Arc::clone(&power), power);
powers.push(Arc::clone(&power));
}

let powers: [Type; N_POWERS] = powers.try_into().unwrap();
let powers: [Arc<Final>; N_POWERS] = powers.try_into().unwrap();
*write = Some(powers);
}

Expand All @@ -49,12 +50,12 @@ fn initialize(write: &mut Option<[Type; N_POWERS]>) {
/// # Panics
///
/// Panics if you request a number `n` greater than or equal to [`Tmr::POWERS_OF_TWO`].
pub fn nth_power_of_2(n: usize) -> Type {
pub fn nth_power_of_2(n: usize) -> Arc<Final> {
POWERS_OF_TWO.with(|arr| {
if arr.borrow().is_none() {
initialize(&mut arr.borrow_mut());
}
debug_assert!(arr.borrow().is_some());
arr.borrow().as_ref().unwrap()[n].shallow_clone()
Arc::clone(&arr.borrow().as_ref().unwrap()[n])
})
}

0 comments on commit 12a1885

Please sign in to comment.