diff --git a/Cargo.lock b/Cargo.lock index aca7f199bb5..936aa636e86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2840,7 +2840,7 @@ dependencies = [ "num-bigint", "num-traits", "proptest", - "proptest-derive", + "proptest-derive 0.4.0", "serde", "serde_json", "strum", @@ -2961,6 +2961,8 @@ dependencies = [ "num-bigint", "num-traits", "petgraph", + "proptest", + "proptest-derive 0.5.0", "rangemap", "regex", "rustc-hash", @@ -3511,6 +3513,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "proptest-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff7ff745a347b87471d859a377a9a404361e7efc2a971d73424a6d183c0fc77" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.64", +] + [[package]] name = "quick-error" version = "1.2.3" diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index d729dabcb04..581d7f1b61d 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -36,6 +36,8 @@ strum_macros = "0.24" [dev-dependencies] base64.workspace = true +proptest.workspace = true +proptest-derive = "0.5.0" [features] experimental_parser = [] diff --git a/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt b/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt new file mode 100644 index 00000000000..80f5c7f1ead --- /dev/null +++ b/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc fc27f4091dfa5f938973048209b5fcf22aefa1cfaffaaa3e349f30e9b1f93f49 # shrinks to infix_and_bindings = (((0: numeric bool) % (Numeric(Shared(RefCell { value: Unbound('2, Numeric(bool)) }): bool) + Numeric(Shared(RefCell { value: Unbound('0, Numeric(bool)) }): bool))), [('0, (0: numeric bool)), ('1, (0: numeric bool)), ('2, (0: numeric bool))]) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 17fae63bc35..2c8a9b6508d 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -98,7 +98,7 @@ impl UnresolvedGeneric { UnresolvedGeneric::Variable(_) => Ok(Kind::Normal), UnresolvedGeneric::Numeric { typ, .. } => { let typ = self.resolve_numeric_kind_type(typ)?; - Ok(Kind::Numeric(Box::new(typ))) + Ok(Kind::numeric(typ)) } UnresolvedGeneric::Resolved(..) => { panic!("Don't know the kind of a resolved generic here") diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index e85563691ba..3c6664dd569 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -19,6 +19,9 @@ pub use visitor::Visitor; pub use expression::*; pub use function::*; +#[cfg(test)] +use proptest_derive::Arbitrary; + use acvm::FieldElement; pub use docs::*; use noirc_errors::Span; @@ -37,6 +40,7 @@ use crate::{ use acvm::acir::AcirField; use iter_extended::vecmap; +#[cfg_attr(test, derive(Arbitrary))] #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Ord, PartialOrd)] pub enum IntegerBitSize { One, diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 9d6d04c07b0..ee79b62671f 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -641,7 +641,7 @@ impl<'context> Elaborator<'context> { let typ = if unresolved_typ.is_type_expression() { self.resolve_type_inner( unresolved_typ.clone(), - &Kind::Numeric(Box::new(Type::default_int_type())), + &Kind::numeric(Type::default_int_type()), ) } else { self.resolve_type(unresolved_typ.clone()) @@ -654,7 +654,7 @@ impl<'context> Elaborator<'context> { }); self.push_err(unsupported_typ_err); } - Kind::Numeric(Box::new(typ)) + Kind::numeric(typ) } else { Kind::Normal } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 722573bcd38..b296c4f1805 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -414,7 +414,7 @@ impl<'context> Elaborator<'context> { let kind = self .interner .get_global_let_statement(id) - .map(|let_statement| Kind::Numeric(Box::new(let_statement.r#type))) + .map(|let_statement| Kind::numeric(let_statement.r#type)) .unwrap_or(Kind::u32()); // TODO(https://github.com/noir-lang/noir/issues/6238): diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 2d15d4927c9..a373441b4e0 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -541,7 +541,7 @@ impl<'a> ModCollector<'a> { name: Rc::new(name.to_string()), type_var: TypeVariable::unbound( type_variable_id, - Kind::Numeric(Box::new(typ)), + Kind::numeric(typ), ), span: name.span(), }); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 88fc87cc500..a0fea3aa774 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -5,6 +5,9 @@ use std::{ rc::Rc, }; +#[cfg(test)] +use proptest_derive::Arbitrary; + use acvm::{AcirField, FieldElement}; use crate::{ @@ -169,6 +172,11 @@ pub enum Kind { } impl Kind { + // Kind::Numeric constructor helper + pub fn numeric(typ: Type) -> Kind { + Kind::Numeric(Box::new(typ)) + } + pub(crate) fn is_error(&self) -> bool { match self.follow_bindings() { Self::Numeric(typ) => *typ == Type::Error, @@ -196,7 +204,7 @@ impl Kind { } pub(crate) fn u32() -> Self { - Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo))) + Self::numeric(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo)) } pub(crate) fn follow_bindings(&self) -> Self { @@ -205,7 +213,7 @@ impl Kind { Self::Normal => Self::Normal, Self::Integer => Self::Integer, Self::IntegerOrField => Self::IntegerOrField, - Self::Numeric(typ) => Self::Numeric(Box::new(typ.follow_bindings())), + Self::Numeric(typ) => Self::numeric(typ.follow_bindings()), } } @@ -671,6 +679,7 @@ impl Shared { /// A restricted subset of binary operators useable on /// type level integers for use in the array length positions of types. +#[cfg_attr(test, derive(Arbitrary))] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub enum BinaryTypeOperator { Addition, @@ -1422,7 +1431,7 @@ impl Type { if self_kind.unifies(&other_kind) { self_kind } else { - Kind::Numeric(Box::new(Type::Error)) + Kind::numeric(Type::Error) } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index df687782afb..ce2c58e71c1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -944,9 +944,7 @@ impl<'interner> Monomorphizer<'interner> { }; let location = self.interner.id_location(expr_id); - if !Kind::Numeric(numeric_typ.clone()) - .unifies(&Kind::Numeric(Box::new(typ.clone()))) - { + if !Kind::Numeric(numeric_typ.clone()).unifies(&Kind::numeric(typ.clone())) { let message = "ICE: Generic's kind does not match expected type"; return Err(MonomorphizationError::InternalError { location, message }); } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index f66a82a622f..b709436cccc 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1,6 +1,7 @@ #![cfg(test)] mod aliases; +mod arithmetic_generics; mod bound_checks; mod imports; mod metaprogramming; @@ -16,7 +17,6 @@ mod visibility; // A test harness will allow for more expressive and readable tests use std::collections::BTreeMap; -use acvm::{AcirField, FieldElement}; use fm::FileId; use iter_extended::vecmap; @@ -37,7 +37,6 @@ use crate::hir::def_collector::dc_crate::DefCollector; use crate::hir::def_map::{CrateDefMap, LocalModuleId}; use crate::hir_def::expr::HirExpression; use crate::hir_def::stmt::HirStatement; -use crate::hir_def::types::{BinaryTypeOperator, Type}; use crate::monomorphization::ast::Program; use crate::monomorphization::errors::MonomorphizationError; use crate::monomorphization::monomorphize; @@ -3210,122 +3209,6 @@ fn as_trait_path_syntax_no_impl() { assert!(matches!(&errors[0].0, TypeError(TypeCheckError::NoMatchingImplFound { .. }))); } -#[test] -fn arithmetic_generics_canonicalization_deduplication_regression() { - let source = r#" - struct ArrData { - a: [Field; N], - b: [Field; N + N - 1], - } - - fn main() { - let _f: ArrData<5> = ArrData { - a: [0; 5], - b: [0; 9], - }; - } - "#; - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); -} - -#[test] -fn arithmetic_generics_checked_cast_zeros() { - let source = r#" - struct W {} - - fn foo(_x: W) -> W<(0 * N) / (N % N)> { - W {} - } - - fn bar(_x: W) -> u1 { - N - } - - fn main() -> pub u1 { - let w_0: W<0> = W {}; - let w: W<_> = foo(w_0); - bar(w) - } - "#; - - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); - - let monomorphization_error = get_monomorphization_error(source); - assert!(monomorphization_error.is_some()); - - // Expect a CheckedCast (0 % 0) failure - let monomorphization_error = monomorphization_error.unwrap(); - if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = - monomorphization_error - { - match length { - Type::CheckedCast { from, to } => { - assert!(matches!(*from.clone(), Type::InfixExpr { .. })); - assert!(matches!(*to.clone(), Type::InfixExpr { .. })); - } - _ => panic!("unexpected length: {:?}", length), - } - assert!(matches!( - err, - TypeCheckError::FailingBinaryOp { op: BinaryTypeOperator::Modulo, lhs: 0, rhs: 0, .. } - )); - } else { - panic!("unexpected error: {:?}", monomorphization_error); - } -} - -#[test] -fn arithmetic_generics_checked_cast_indirect_zeros() { - let source = r#" - struct W {} - - fn foo(_x: W) -> W<(N - N) % (N - N)> { - W {} - } - - fn bar(_x: W) -> Field { - N - } - - fn main() { - let w_0: W<0> = W {}; - let w = foo(w_0); - let _ = bar(w); - } - "#; - - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); - - let monomorphization_error = get_monomorphization_error(source); - assert!(monomorphization_error.is_some()); - - // Expect a CheckedCast (0 % 0) failure - let monomorphization_error = monomorphization_error.unwrap(); - if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = - monomorphization_error - { - match length { - Type::CheckedCast { from, to } => { - assert!(matches!(*from.clone(), Type::InfixExpr { .. })); - assert!(matches!(*to.clone(), Type::InfixExpr { .. })); - } - _ => panic!("unexpected length: {:?}", length), - } - match err { - TypeCheckError::ModuloOnFields { lhs, rhs, .. } => { - assert_eq!(lhs.clone(), FieldElement::zero()); - assert_eq!(rhs.clone(), FieldElement::zero()); - } - _ => panic!("expected ModuloOnFields, but found: {:?}", err), - } - } else { - panic!("unexpected error: {:?}", monomorphization_error); - } -} - #[test] fn infer_globals_to_u32_from_type_use() { let src = r#" diff --git a/compiler/noirc_frontend/src/tests/arithmetic_generics.rs b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs new file mode 100644 index 00000000000..b7c4834a84a --- /dev/null +++ b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs @@ -0,0 +1,366 @@ +#![cfg(test)] + +use proptest::arbitrary::any; +use proptest::collection; +use proptest::prelude::*; +use proptest::result::maybe_ok; +use proptest::strategy; + +use acvm::{AcirField, FieldElement}; + +use super::get_program_errors; +use crate::ast::{IntegerBitSize, Signedness}; +use crate::hir::type_check::TypeCheckError; +use crate::hir_def::types::{BinaryTypeOperator, Kind, Type, TypeVariable, TypeVariableId}; +use crate::monomorphization::errors::MonomorphizationError; +use crate::tests::get_monomorphization_error; + +#[test] +fn arithmetic_generics_canonicalization_deduplication_regression() { + let source = r#" + struct ArrData { + a: [Field; N], + b: [Field; N + N - 1], + } + + fn main() { + let _f: ArrData<5> = ArrData { + a: [0; 5], + b: [0; 9], + }; + } + "#; + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); +} + +#[test] +fn arithmetic_generics_checked_cast_zeros() { + let source = r#" + struct W {} + + fn foo(_x: W) -> W<(0 * N) / (N % N)> { + W {} + } + + fn bar(_x: W) -> u1 { + N + } + + fn main() -> pub u1 { + let w_0: W<0> = W {}; + let w: W<_> = foo(w_0); + bar(w) + } + "#; + + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); + + let monomorphization_error = get_monomorphization_error(source); + assert!(monomorphization_error.is_some()); + + // Expect a CheckedCast (0 % 0) failure + let monomorphization_error = monomorphization_error.unwrap(); + if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = + monomorphization_error + { + match length { + Type::CheckedCast { from, to } => { + assert!(matches!(*from.clone(), Type::InfixExpr { .. })); + assert!(matches!(*to.clone(), Type::InfixExpr { .. })); + } + _ => panic!("unexpected length: {:?}", length), + } + assert!(matches!( + err, + TypeCheckError::FailingBinaryOp { op: BinaryTypeOperator::Modulo, lhs: 0, rhs: 0, .. } + )); + } else { + panic!("unexpected error: {:?}", monomorphization_error); + } +} + +#[test] +fn arithmetic_generics_checked_cast_indirect_zeros() { + let source = r#" + struct W {} + + fn foo(_x: W) -> W<(N - N) % (N - N)> { + W {} + } + + fn bar(_x: W) -> Field { + N + } + + fn main() { + let w_0: W<0> = W {}; + let w = foo(w_0); + let _ = bar(w); + } + "#; + + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); + + let monomorphization_error = get_monomorphization_error(source); + assert!(monomorphization_error.is_some()); + + // Expect a CheckedCast (0 % 0) failure + let monomorphization_error = monomorphization_error.unwrap(); + if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = + monomorphization_error + { + match length { + Type::CheckedCast { from, to } => { + assert!(matches!(*from.clone(), Type::InfixExpr { .. })); + assert!(matches!(*to.clone(), Type::InfixExpr { .. })); + } + _ => panic!("unexpected length: {:?}", length), + } + match err { + TypeCheckError::ModuloOnFields { lhs, rhs, .. } => { + assert_eq!(lhs.clone(), FieldElement::zero()); + assert_eq!(rhs.clone(), FieldElement::zero()); + } + _ => panic!("expected ModuloOnFields, but found: {:?}", err), + } + } else { + panic!("unexpected error: {:?}", monomorphization_error); + } +} + +prop_compose! { + // maximum_size must be non-zero + fn arbitrary_u128_field_element(maximum_size: u128) + (u128_value in any::()) + -> FieldElement + { + assert!(maximum_size != 0); + FieldElement::from(u128_value % maximum_size) + } +} + +// NOTE: this is roughly the same method from acvm/tests/solver +prop_compose! { + // Use both `u128` and hex proptest strategies + fn arbitrary_field_element() + (u128_or_hex in maybe_ok(any::(), "[0-9a-f]{64}")) + -> FieldElement + { + match u128_or_hex { + Ok(number) => FieldElement::from(number), + Err(hex) => FieldElement::from_hex(&hex).expect("should accept any 32 byte hex string"), + } + } +} + +// Generate (arbitrary_unsigned_type, generator for that type) +fn arbitrary_unsigned_type_with_generator() -> BoxedStrategy<(Type, BoxedStrategy)> { + prop_oneof![ + strategy::Just((Type::FieldElement, arbitrary_field_element().boxed())), + any::().prop_map(|bit_size| { + let typ = Type::Integer(Signedness::Unsigned, bit_size); + let maximum_size = typ.integral_maximum_size().unwrap().to_u128(); + (typ, arbitrary_u128_field_element(maximum_size).boxed()) + }), + strategy::Just((Type::Bool, arbitrary_u128_field_element(1).boxed())), + ] + .boxed() +} + +prop_compose! { + fn arbitrary_variable(typ: Type, num_variables: usize) + (variable_index in any::()) + -> Type { + assert!(num_variables != 0); + let id = TypeVariableId(variable_index % num_variables); + let kind = Kind::numeric(typ.clone()); + let var = TypeVariable::unbound(id, kind); + Type::TypeVariable(var) + } +} + +fn first_n_variables(typ: Type, num_variables: usize) -> impl Iterator { + (0..num_variables).map(move |id| { + let id = TypeVariableId(id); + let kind = Kind::numeric(typ.clone()); + TypeVariable::unbound(id, kind) + }) +} + +fn arbitrary_infix_expr( + typ: Type, + arbitrary_value: BoxedStrategy, + num_variables: usize, +) -> impl Strategy { + let leaf = prop_oneof![ + arbitrary_variable(typ.clone(), num_variables), + arbitrary_value.prop_map(move |value| Type::Constant(value, Kind::numeric(typ.clone()))), + ]; + + leaf.prop_recursive( + 8, // 8 levels deep maximum + 256, // Shoot for maximum size of 256 nodes + 10, // We put up to 10 items per collection + |inner| { + (inner.clone(), any::(), inner) + .prop_map(|(lhs, op, rhs)| Type::InfixExpr(Box::new(lhs), op, Box::new(rhs))) + }, + ) +} + +prop_compose! { + // (infix_expr, type, generator) + fn arbitrary_infix_expr_type_gen(num_variables: usize) + (type_and_gen in arbitrary_unsigned_type_with_generator()) + (infix_expr in arbitrary_infix_expr(type_and_gen.clone().0, type_and_gen.clone().1, num_variables), type_and_gen in Just(type_and_gen)) + -> (Type, Type, BoxedStrategy) { + let (typ, value_generator) = type_and_gen; + (infix_expr, typ, value_generator) + } +} + +prop_compose! { + // (Type::InfixExpr, numeric kind, bindings) + fn arbitrary_infix_expr_with_bindings_sized(num_variables: usize) + (infix_type_gen in arbitrary_infix_expr_type_gen(num_variables)) + (values in collection::vec(infix_type_gen.clone().2, num_variables), infix_type_gen in Just(infix_type_gen)) + -> (Type, Type, Vec<(TypeVariable, Type)>) { + let (infix_expr, typ, _value_generator) = infix_type_gen; + let bindings: Vec<_> = first_n_variables(typ.clone(), num_variables) + .zip(values.iter().map(|value| { + Type::Constant(*value, Kind::numeric(typ.clone())) + })) + .collect(); + (infix_expr, typ, bindings) + } +} + +prop_compose! { + // the lint misfires on 'num_variables' + #[allow(unused_variables)] + fn arbitrary_infix_expr_with_bindings(max_num_variables: usize) + (num_variables in any::().prop_map(move |num_variables| (num_variables % max_num_variables).clamp(1, max_num_variables))) + (infix_type_bindings in arbitrary_infix_expr_with_bindings_sized(num_variables), num_variables in Just(num_variables)) + -> (Type, Type, Vec<(TypeVariable, Type)>) { + infix_type_bindings + } +} + +#[test] +fn instantiate_after_canonicalize_smoke_test() { + let field_element_kind = Kind::numeric(Type::FieldElement); + let x_var = TypeVariable::unbound(TypeVariableId(0), field_element_kind.clone()); + let x_type = Type::TypeVariable(x_var.clone()); + let one = Type::Constant(FieldElement::one(), field_element_kind.clone()); + + let lhs = Type::InfixExpr( + Box::new(x_type.clone()), + BinaryTypeOperator::Addition, + Box::new(one.clone()), + ); + let rhs = + Type::InfixExpr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); + + // canonicalize + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + + // bind vars + let two = Type::Constant(FieldElement::one() + FieldElement::one(), field_element_kind.clone()); + x_var.bind(two); + + // canonicalize (expect constant) + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + + // ensure we've canonicalized to constants + assert!(matches!(lhs, Type::Constant(..))); + assert!(matches!(rhs, Type::Constant(..))); + + // ensure result kinds are the same as the original kind + assert_eq!(lhs.kind(), field_element_kind); + assert_eq!(rhs.kind(), field_element_kind); + + // ensure results are the same + assert_eq!(lhs, rhs); +} + +proptest! { + #[test] + // Expect cases that don't resolve to constants, e.g. see + // `arithmetic_generics_checked_cast_indirect_zeros` + #[should_panic(expected = "matches!(infix, Type :: Constant(..))")] + fn instantiate_before_or_after_canonicalize(infix_type_bindings in arbitrary_infix_expr_with_bindings(10)) { + let (infix, typ, bindings) = infix_type_bindings; + + // canonicalize + let infix_canonicalized = infix.canonicalize(); + + // bind vars + for (var, binding) in bindings { + var.bind(binding); + } + + // attempt to canonicalize to a constant + let infix = infix.canonicalize(); + let infix_canonicalized = infix_canonicalized.canonicalize(); + + // ensure we've canonicalized to constants + prop_assert!(matches!(infix, Type::Constant(..))); + prop_assert!(matches!(infix_canonicalized, Type::Constant(..))); + + // ensure result kinds are the same as the original kind + let kind = Kind::numeric(typ); + prop_assert_eq!(infix.kind(), kind.clone()); + prop_assert_eq!(infix_canonicalized.kind(), kind); + + // ensure results are the same + prop_assert_eq!(infix, infix_canonicalized); + } + + #[test] + fn instantiate_before_or_after_canonicalize_checked_cast(infix_type_bindings in arbitrary_infix_expr_with_bindings(10)) { + let (infix, typ, bindings) = infix_type_bindings; + + // wrap in CheckedCast + let infix = Type::CheckedCast { + from: Box::new(infix.clone()), + to: Box::new(infix) + }; + + // canonicalize + let infix_canonicalized = infix.canonicalize(); + + // bind vars + for (var, binding) in bindings { + var.bind(binding); + } + + // attempt to canonicalize to a constant + let infix = infix.canonicalize(); + let infix_canonicalized = infix_canonicalized.canonicalize(); + + // ensure result kinds are the same as the original kind + let kind = Kind::numeric(typ); + prop_assert_eq!(infix.kind(), kind.clone()); + prop_assert_eq!(infix_canonicalized.kind(), kind.clone()); + + // ensure the results are still wrapped in CheckedCast's + match (&infix, &infix_canonicalized) { + (Type::CheckedCast { from, to }, Type::CheckedCast { from: from_canonicalized, to: to_canonicalized }) => { + // ensure from's are the same + prop_assert_eq!(from, from_canonicalized); + + // ensure to's have the same kinds + prop_assert_eq!(to.kind(), kind.clone()); + prop_assert_eq!(to_canonicalized.kind(), kind); + } + _ => { + prop_assert!(false, "expected CheckedCast"); + } + } + } +}