Skip to content

Commit

Permalink
Merge #25: Fixes
Browse files Browse the repository at this point in the history
a97ae56 Add DoublePowwUsize (Christian Lewe)
85a8104 Add NonZeroPow2Usize (Christian Lewe)
2d122bc Use NonZeroUsize (Christian Lewe)
adc9408 Add UnsignedDecimal (Christian Lewe)
a519bd0 Add false | true shorthands (Christian Lewe)
b49cc42 Named: Add OIH shorthands (Christian Lewe)
51e3f76 Switch Vec<A> to Arc<[A]> (Christian Lewe)
f5640d2 Compile inner single expression (Christian Lewe)
b39ad29 Grammar: Make rules atomic (Christian Lewe)
1a6a9e2 Test: Print compressed program (Christian Lewe)
d50ea74 Fix boolean match statement (Christian Lewe)
4f5983b Fix inferred bound of list literals (Christian Lewe)

Pull request description:

  Fixes and quality-of-live improvements that came up during the other PRs. I created a separate PR so we can merge these fixes faster.

ACKs for top commit:
  apoelstra:
    ACK a97ae56

Tree-SHA512: c2c78af5467c2d54e71fa15238b21c5919074643775ea0354c5cf203f67fc87765dbe3bf4d20426c47efe23715f1b372928242584011b8910f0f202dab1eb4e4
  • Loading branch information
uncomputable committed Apr 21, 2024
2 parents 2501c34 + a97ae56 commit dff5781
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 85 deletions.
38 changes: 22 additions & 16 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use std::{str::FromStr, sync::Arc};
use simplicity::{jet::Elements, node, Cmr, FailEntropy};

use crate::array::{BTreeSlice, Partition};
use crate::num::NonZeroPow2Usize;
use crate::parse::{Pattern, SingleExpressionInner, UIntType};
use crate::{
named::{ConstructExt, NamedConstructNode, ProgExt},
parse::{
Expression, ExpressionInner, FuncCall, FuncType, Program, SingleExpression, Statement, Type,
},
parse::{Expression, ExpressionInner, FuncCall, FuncType, Program, Statement, Type},
scope::GlobalScope,
ProgNode,
};
Expand Down Expand Up @@ -107,14 +106,14 @@ impl Expression {
scope.pop_scope();
res
}
ExpressionInner::SingleExpression(e) => e.eval(scope, reqd_ty),
ExpressionInner::SingleExpression(e) => e.inner.eval(scope, reqd_ty),
}
}
}

impl SingleExpression {
impl SingleExpressionInner {
pub fn eval(&self, scope: &mut GlobalScope, reqd_ty: Option<&Type>) -> ProgNode {
let res = match &self.inner {
let res = match self {
SingleExpressionInner::Unit => ProgNode::unit(),
SingleExpressionInner::Left(l) => {
let l = l.eval(scope, None);
Expand Down Expand Up @@ -165,15 +164,24 @@ impl SingleExpression {
right,
} => {
let mut l_scope = scope.clone();
if let Some(x) = left.pattern.get_identifier() {
l_scope.insert(Pattern::Identifier(x.clone()));
}
l_scope.insert(
left.pattern
.get_identifier()
.cloned()
.map(Pattern::Identifier)
.unwrap_or(Pattern::Ignore),
);
let l_compiled = left.expression.eval(&mut l_scope, reqd_ty);

let mut r_scope = scope.clone();
if let Some(y) = right.pattern.get_identifier() {
r_scope.insert(Pattern::Identifier(y.clone()));
}
r_scope.insert(
right
.pattern
.get_identifier()
.cloned()
.map(Pattern::Identifier)
.unwrap_or(Pattern::Ignore),
);
let r_compiled = right.expression.eval(&mut r_scope, reqd_ty);

// TODO: Enforce target type A + B for m_expr
Expand Down Expand Up @@ -202,12 +210,10 @@ impl SingleExpression {
let bound = if let Some(Type::List(_, bound)) = reqd_ty {
*bound
} else {
elements.len().next_power_of_two()
NonZeroPow2Usize::next(elements.len().saturating_add(1))
};
debug_assert!(bound.is_power_of_two());
debug_assert!(2 <= bound);

let partition = Partition::from_slice(&nodes, bound / 2);
let partition = Partition::from_slice(&nodes, bound.get() / 2);
let process = |block: &[ProgNode]| -> ProgNode {
if block.is_empty() {
ProgNode::injl(ProgNode::unit())
Expand Down
19 changes: 8 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod array;
pub mod compile;
pub mod dummy_env;
pub mod named;
pub mod num;
pub mod parse;
pub mod scope;

Expand Down Expand Up @@ -34,7 +35,6 @@ use crate::{
#[grammar = "minimal.pest"]
pub struct IdentParser;


pub fn _compile(file: &Path) -> Arc<Node<Named<Commit<Elements>>>> {
let file = std::fs::read_to_string(file).unwrap();
let mut pairs = IdentParser::parse(Rule::program, &file).unwrap_or_else(|e| panic!("{}", e));
Expand Down Expand Up @@ -179,22 +179,12 @@ mod tests {
let prog = Program { statements: stmts };
let mut scope = GlobalScope::new();
let simplicity_prog = prog.eval(&mut scope);
let mut vec = Vec::new();
let mut writer = BitWriter::new(&mut vec);
encode::encode_program(&simplicity_prog, &mut writer).unwrap();
println!("{}", Base64Display::new(&vec, &STANDARD));
dbg!(&simplicity_prog);
let commit_node = simplicity_prog
.finalize_types_main()
.expect("Type check error");
// let commit_node = commit_node.to_commit_node();
let simplicity_prog =
Arc::<_>::try_unwrap(commit_node).expect("Only one reference to commit node");
dbg!(&simplicity_prog);
let mut vec = Vec::new();
let mut writer = BitWriter::new(&mut vec);
let _encoded = encode::encode_program(&simplicity_prog, &mut writer).unwrap();
println!("{}", Base64Display::new(&vec, &STANDARD));

struct MyConverter;

Expand Down Expand Up @@ -245,6 +235,13 @@ mod tests {
let redeem_prog = simplicity_prog
.convert::<NoSharing, Redeem<Elements>, _>(&mut MyConverter)
.unwrap();

let mut vec = Vec::new();
let mut writer = BitWriter::new(&mut vec);
let _encoded = encode::encode_program(&redeem_prog, &mut writer).unwrap();
dbg!(&redeem_prog);
println!("{}", Base64Display::new(&vec, &STANDARD));

let mut bit_mac = BitMachine::for_program(&redeem_prog);
let env = dummy_env::dummy();
bit_mac
Expand Down
32 changes: 16 additions & 16 deletions src/minimal.pest
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,50 @@ witness_name = @{ (ASCII_ALPHANUMERIC | "_")+ }
reserved = _{ jet | builtin }

variable_pattern = { identifier }
ignore_pattern = { "_" }
ignore_pattern = @{ "_" }
product_pattern = { "(" ~ pattern ~ "," ~ pattern ~ ")" }
array_pattern = { "[" ~ pattern ~ ("," ~ pattern)* ~ ","? ~ "]" }
pattern = { ignore_pattern | product_pattern | array_pattern | variable_pattern }
assignment = { "let" ~ pattern ~ (":" ~ ty)? ~ "=" ~ expression }

left_pattern = { "Left(" ~ identifier ~ ")" }
right_pattern = { "Right(" ~ identifier ~ ")" }
none_pattern = { "None" }
none_pattern = @{ "None" }
some_pattern = { "Some(" ~ identifier ~ ")" }
false_pattern = { "false" }
true_pattern = { "true" }
false_pattern = @{ "false" }
true_pattern = @{ "true" }
match_pattern = { left_pattern | right_pattern | none_pattern | some_pattern | false_pattern | true_pattern }

unit_type = { "()" }
unit_type = @{ "()" }
sum_type = { "Either<" ~ ty ~ "," ~ ty ~ ">" }
product_type = { "(" ~ ty ~ "," ~ ty ~ ")" }
option_type = { "Option<" ~ ty ~ ">" }
boolean_type = { "bool" }
unsigned_type = { "u128" | "u256" | "u16" | "u32" | "u64" | "u1" | "u2" | "u4" | "u8" }
array_size = { ASCII_DIGIT+ }
boolean_type = @{ "bool" }
unsigned_type = @{ "u128" | "u256" | "u16" | "u32" | "u64" | "u1" | "u2" | "u4" | "u8" }
array_size = @{ ASCII_DIGIT+ }
array_type = { "[" ~ ty ~ ";" ~ array_size ~ "]" }
list_bound = { ASCII_DIGIT+ }
list_bound = @{ ASCII_DIGIT+ }
list_type = { "List<" ~ ty ~ "," ~ list_bound ~ ">" }
ty = { unit_type | sum_type | product_type | option_type | boolean_type | unsigned_type | array_type | list_type }

expression = { block_expression | single_expression }
block_expression = { "{" ~ (statement ~ ";")* ~ expression ~ "}" }
unit_expr = { "()" }
unit_expr = @{ "()" }
left_expr = { "Left(" ~ expression ~ ")" }
right_expr = { "Right(" ~ expression ~ ")" }
product_expr = { "(" ~ expression ~ "," ~ expression ~ ")" }
none_expr = { "None" }
none_expr = @{ "None" }
some_expr = { "Some(" ~ expression ~ ")" }
false_expr = { "false" }
true_expr = { "true" }
false_expr = @{ "false" }
true_expr = @{ "true" }
jet_expr = { jet ~ "(" ~ (expression ~ ("," ~ expression)*)? ~ ")" }
unwrap_left_expr = { "unwrap_left(" ~ expression ~ ")" }
unwrap_right_expr = { "unwrap_right(" ~ expression ~ ")" }
unwrap_expr = { "unwrap(" ~ expression ~ ")" }
func_call = { jet_expr | unwrap_left_expr | unwrap_right_expr | unwrap_expr }
unsigned_integer = { ASCII_DIGIT+ }
bit_string = { "0b" ~ ASCII_BIN_DIGIT+ }
byte_string = { "0x" ~ ASCII_HEX_DIGIT+ }
unsigned_integer = @{ ASCII_DIGIT+ }
bit_string = @{ "0b" ~ ASCII_BIN_DIGIT+ }
byte_string = @{ "0x" ~ ASCII_HEX_DIGIT+ }
witness_expr = { "witness(\"" ~ witness_name ~ "\")" }
variable_expr = { identifier }
match_arm = { match_pattern ~ "=>" ~ (single_expression ~ "," | block_expression ~ ","?) }
Expand Down
57 changes: 56 additions & 1 deletion src/named.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl node::Marker for Named<Construct<Elements>> {
}
}

pub trait ProgExt {
pub trait ProgExt: Sized {
fn unit() -> Self;

fn iden() -> Self;
Expand Down Expand Up @@ -173,6 +173,61 @@ pub trait ProgExt {
fn jet(jet: Elements) -> Self;

fn const_word(v: Arc<Value>) -> Self;

fn o() -> SelectorBuilder<Self> {
SelectorBuilder::default().o()
}

fn i() -> SelectorBuilder<Self> {
SelectorBuilder::default().i()
}

fn _false() -> Self {
Self::injl(Self::unit())
}

fn _true() -> Self {
Self::injr(Self::unit())
}
}

#[derive(Debug, Clone, Hash)]
pub struct SelectorBuilder<P> {
selection: Vec<bool>,
program: PhantomData<P>,
}

impl<P> Default for SelectorBuilder<P> {
fn default() -> Self {
Self {
selection: Vec::default(),
program: PhantomData,
}
}
}

impl<P: ProgExt> SelectorBuilder<P> {
pub fn o(mut self) -> Self {
self.selection.push(false);
self
}

pub fn i(mut self) -> Self {
self.selection.push(true);
self
}

pub fn h(self) -> P {
let mut ret = P::iden();
for bit in self.selection.into_iter().rev() {
match bit {
false => ret = P::take(ret),
true => ret = P::drop_(ret),
}
}

ret
}
}

impl ProgExt for ProgNode {
Expand Down
108 changes: 108 additions & 0 deletions src/num.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/// Implementation for newtypes that wrap a number `u8`, `u16`, ...
/// such that the number has some property.
/// The newtype needs to have a constructor `Self::new(inner) -> Option<Self>`.
macro_rules! checked_num {
(
$wrapper: ident,
$inner: ty,
$description: expr
) => {
impl $wrapper {
/// Access the value as a primitive type.
pub const fn get(&self) -> usize {
self.0
}
}

impl std::fmt::Display for $wrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl std::fmt::Debug for $wrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

impl std::str::FromStr for $wrapper {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let n = s.parse::<$inner>().map_err(|e| e.to_string())?;
Self::new(n).ok_or(format!("{s} is not {}", $description))
}
}
};
}

/// An integer that is known to be a power of two with nonzero exponent.
///
/// The integer is equal to 2^n for some n > 0.
///
/// The integer is strictly greater than 1.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct NonZeroPow2Usize(usize);

impl NonZeroPow2Usize {
/// Smallest power of two with nonzero exponent.
// FIXME `std::option::Option::<T>::unwrap` is not yet stable as a const fn
// pub const TWO: Self = Self::new(2).unwrap();
pub const TWO: Self = Self(2);

/// Create a power of two with nonzero exponent.
pub const fn new(n: usize) -> Option<Self> {
if n.is_power_of_two() && 1 < n {
Some(Self(n))
} else {
None
}
}

/// Create the smallest power of two with nonzero exponent greater equal `n`.
pub const fn next(n: usize) -> Self {
if n < 2 {
Self::TWO
} else {
// FIXME `std::option::Option::<T>::unwrap` is not yet stable as a const fn
// Self::new(n.next_power_of_two()).unwrap()
Self(n.next_power_of_two())
}
}

/// Return the binary logarithm of the value.
///
/// The integer is equal to 2^n. Return n.
pub const fn log2(self) -> u32 {
self.0.trailing_zeros()
}
}

checked_num!(NonZeroPow2Usize, usize, "a power of two greater than 1");

/// An integer that is known to be a power _of a power_ of two.
///
/// The integer is equal to 2^(2^n) for some n ≥ 0.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct DoublePow2Usize(usize);

checked_num!(DoublePow2Usize, usize, "a double power of two");

impl DoublePow2Usize {
/// Create a double power of two.
pub const fn new(n: usize) -> Option<Self> {
if n.is_power_of_two() && n.trailing_zeros().is_power_of_two() {
Some(Self(n))
} else {
None
}
}

/// Return the binary logarithm _of the binary logarithm_ of the value.
///
/// The integer is equal to 2^(2^n). Return n.
pub const fn log2_log2(self) -> u32 {
self.0.trailing_zeros().trailing_zeros()
}
}
Loading

0 comments on commit dff5781

Please sign in to comment.