diff --git a/parser/src/parser/tests/inlining.rs b/parser/src/parser/tests/inlining.rs index 5011d9a1..c2501a1e 100644 --- a/parser/src/parser/tests/inlining.rs +++ b/parser/src/parser/tests/inlining.rs @@ -1113,3 +1113,256 @@ fn test_inlining_constraints_with_folded_comprehensions_in_evaluator() { assert_eq!(program, expected); } + +/// This test originally reproduced the bug in air-script#340, but as of this commit +/// that bug is fixed. This test remains not to prevent regressions necessarily, but +/// to add a more realistic test case to our test suite, and potentially catch bugs +/// we accidentally introduce in the future which blow up on this code. +#[test] +fn test_repro_issue340() { + // NOTE: This code is exactly what was written in #340, but the only significant + // part is the fact that a panic was raised due to visiting one of the folded + // compehensions in `imm_reconstruction` as if it was a constraint comprehension, + // due to incorrectly propagating the constraint flag in that situation. The + // bug no longer exists, but this test program may be useful for other tests of + // a more real-world nature, so it is preserved here + let root = r#" + def root + + trace_columns: + main: [instruction_word, instruction_bits[32], immediate, s] + + public_inputs: + stack_inputs: [16] + stack_outputs: [16] + + periodic_columns: + k0: [1, 1, 1, 1, 1, 1, 1, 0] + + boundary_constraints: + # define boundary constraints against the main trace at the first row of the trace. + enf instruction_word.first = 0 + enf instruction_word.last = 2 + + integrity_constraints: + # The instruction bit decomposition must be bits + enf b^2 = b for b in instruction_bits + + # Ensure they add up to the instruction word: + let word_sum = sum([2^i * a for (i, a) in (0..32, instruction_bits)]) + enf instruction_word = word_sum + + enf match: + case s: imm_reconstruction([instruction_bits, immediate]) + case !s: immediate = 0 + + # The highest bit is a sign bit, so we sign extend then reconstruct from the other bits + ev imm_reconstruction([instruction_bits[32], immediate]): + let sign_bit = instruction_bits[31] + let high_bit_sum = sum([sign_bit*2^i for i in (11..32)]) + let immediate_bits = instruction_bits[20..31] + let low_bit_sum = sum([immediate_bit * 2^i for (i, immediate_bit) in (0..11, instruction_bits[20..31])]) + enf immediate = low_bit_sum + high_bit_sum + enf sign_bit = 1 + "#; + + let test = ParseTest::new(); + let program = match test.parse_program(root) { + Err(err) => { + test.diagnostics.emit(err); + panic!("expected parsing to succeed, see diagnostics for details"); + } + Ok(ast) => ast, + }; + + let mut pipeline = + ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); + let program = pipeline.run(program).unwrap(); + + let mut expected = Program::new(ident!(root)); + expected.trace_columns.push(trace_segment!( + 0, + "$main", + [ + (instruction_word, 1), + (instruction_bits, 32), + (immediate, 1), + (s, 1) + ] + )); + expected.public_inputs.insert( + ident!(stack_inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16), + ); + expected.public_inputs.insert( + ident!(stack_outputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_outputs), 16), + ); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(instruction_word, Boundary::First, Type::Felt), + int!(0) + ))); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(instruction_word, Boundary::Last, Type::Felt), + int!(2) + ))); + for i in 0..32 { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!(instruction_bits)), + access_type: AccessType::Index(i), + offset: 0, + ty: Some(Type::Felt), + }); + expected + .integrity_constraints + .push(enforce!(eq!(exp!(access.clone(), int!(2)), access.clone()))); + } + let word_sum = (1..32) + .into_iter() + .fold(access!("%lc0", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), + }); + add!(acc, access) + }); + let high_bit_sum = (33..53) + .into_iter() + .fold(access!("%lc32", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), + }); + add!(acc, access) + }); + let low_bit_sum = (54..64) + .into_iter() + .fold(access!("%lc53", Type::Felt), |acc, i| { + let access = ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(format!("%lc{}", i)), + )), + access_type: AccessType::Default, + offset: 0, + ty: Some(Type::Felt), + }); + add!(acc, access) + }); + let low_bit_sum_body = let_!(low_bit_sum = expr!(low_bit_sum) => + enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt))), when access!(s, Type::Felt)), + enforce!(eq!(access!(instruction_bits[31], Type::Felt), int!(1)), when access!(s, Type::Felt))); + let high_bit_sum_body = let_!(high_bit_sum = expr!(high_bit_sum) + => let_!("%lc53" = expr!(mul!(access!(instruction_bits[20], Type::Felt), int!(1))) + => let_!("%lc54" = expr!(mul!(access!(instruction_bits[21], Type::Felt), int!(2))) + => let_!("%lc55" = expr!(mul!(access!(instruction_bits[22], Type::Felt), int!(4))) + => let_!("%lc56" = expr!(mul!(access!(instruction_bits[23], Type::Felt), int!(8))) + => let_!("%lc57" = expr!(mul!(access!(instruction_bits[24], Type::Felt), int!(16))) + => let_!("%lc58" = expr!(mul!(access!(instruction_bits[25], Type::Felt), int!(32))) + => let_!("%lc59" = expr!(mul!(access!(instruction_bits[26], Type::Felt), int!(64))) + => let_!("%lc60" = expr!(mul!(access!(instruction_bits[27], Type::Felt), int!(128))) + => let_!("%lc61" = expr!(mul!(access!(instruction_bits[28], Type::Felt), int!(256))) + => let_!("%lc62" = expr!(mul!(access!(instruction_bits[29], Type::Felt), int!(512))) + => let_!("%lc63" = expr!(mul!(access!(instruction_bits[30], Type::Felt), int!(1024))) + => low_bit_sum_body)))))))))))); + let word_sum_body = let_!(word_sum = expr!(word_sum) + => enforce!(eq!(access!(instruction_word, Type::Felt), access!(word_sum, Type::Felt))), + let_!("%lc32" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2048))) + => let_!("%lc33" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4096))) + => let_!("%lc34" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8192))) + => let_!("%lc35" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16384))) + => let_!("%lc36" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(32768))) + => let_!("%lc37" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(65536))) + => let_!("%lc38" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(131072))) + => let_!("%lc39" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(262144))) + => let_!("%lc40" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(524288))) + => let_!("%lc41" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1048576))) + => let_!("%lc42" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2097152))) + => let_!("%lc43" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4194304))) + => let_!("%lc44" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8388608))) + => let_!("%lc45" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16777216))) + => let_!("%lc46" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(33554432))) + => let_!("%lc47" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(67108864))) + => let_!("%lc48" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(134217728))) + => let_!("%lc49" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(268435456))) + => let_!("%lc50" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(536870912))) + => let_!("%lc51" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1073741824))) + => let_!("%lc52" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2147483648))) + => high_bit_sum_body))))))))))))))))))))), + enforce!(eq!(access!(immediate, Type::Felt), int!(0)), when not!(access!(s, Type::Felt))) + ); + + expected.integrity_constraints.push( + let_!("%lc0" = expr!(mul!(int!(1), access!(instruction_bits[0], Type::Felt))) + => let_!("%lc1" = expr!(mul!(int!(2), access!(instruction_bits[1], Type::Felt))) + => let_!("%lc2" = expr!(mul!(int!(4), access!(instruction_bits[2], Type::Felt))) + => let_!("%lc3" = expr!(mul!(int!(8), access!(instruction_bits[3], Type::Felt))) + => let_!("%lc4" = expr!(mul!(int!(16), access!(instruction_bits[4], Type::Felt))) + => let_!("%lc5" = expr!(mul!(int!(32), access!(instruction_bits[5], Type::Felt))) + => let_!("%lc6" = expr!(mul!(int!(64), access!(instruction_bits[6], Type::Felt))) + => let_!("%lc7" = expr!(mul!(int!(128), access!(instruction_bits[7], Type::Felt))) + => let_!("%lc8" = expr!(mul!(int!(256), access!(instruction_bits[8], Type::Felt))) + => let_!("%lc9" = expr!(mul!(int!(512), access!(instruction_bits[9], Type::Felt))) + => let_!("%lc10" = expr!(mul!(int!(1024), access!(instruction_bits[10], Type::Felt))) + => let_!("%lc11" = expr!(mul!(int!(2048), access!(instruction_bits[11], Type::Felt))) + => let_!("%lc12" = expr!(mul!(int!(4096), access!(instruction_bits[12], Type::Felt))) + => let_!("%lc13" = expr!(mul!(int!(8192), access!(instruction_bits[13], Type::Felt))) + => let_!("%lc14" = expr!(mul!(int!(16384), access!(instruction_bits[14], Type::Felt))) + => let_!("%lc15" = expr!(mul!(int!(32768), access!(instruction_bits[15], Type::Felt))) + => let_!("%lc16" = expr!(mul!(int!(65536), access!(instruction_bits[16], Type::Felt))) + => let_!("%lc17" = expr!(mul!(int!(131072), access!(instruction_bits[17], Type::Felt))) + => let_!("%lc18" = expr!(mul!(int!(262144), access!(instruction_bits[18], Type::Felt))) + => let_!("%lc19" = expr!(mul!(int!(524288), access!(instruction_bits[19], Type::Felt))) + => let_!("%lc20" = expr!(mul!(int!(1048576), access!(instruction_bits[20], Type::Felt))) + => let_!("%lc21" = expr!(mul!(int!(2097152), access!(instruction_bits[21], Type::Felt))) + => let_!("%lc22" = expr!(mul!(int!(4194304), access!(instruction_bits[22], Type::Felt))) + => let_!("%lc23" = expr!(mul!(int!(8388608), access!(instruction_bits[23], Type::Felt))) + => let_!("%lc24" = expr!(mul!(int!(16777216), access!(instruction_bits[24], Type::Felt))) + => let_!("%lc25" = expr!(mul!(int!(33554432), access!(instruction_bits[25], Type::Felt))) + => let_!("%lc26" = expr!(mul!(int!(67108864), access!(instruction_bits[26], Type::Felt))) + => let_!("%lc27" = expr!(mul!(int!(134217728), access!(instruction_bits[27], Type::Felt))) + => let_!("%lc28" = expr!(mul!(int!(268435456), access!(instruction_bits[28], Type::Felt))) + => let_!("%lc29" = expr!(mul!(int!(536870912), access!(instruction_bits[29], Type::Felt))) + => let_!("%lc30" = expr!(mul!(int!(1073741824), access!(instruction_bits[30], Type::Felt))) + => let_!("%lc31" = expr!(mul!(int!(2147483648), access!(instruction_bits[31], Type::Felt))) + => word_sum_body)))))))))))))))))))))))))))))))), + ); + // The evaluator definition is never modified by constant propagation or inlining + let body = vec![ + let_!(sign_bit = expr!(access!(instruction_bits[31], Type::Felt)) + => let_!(high_bit_sum = expr!(call!(sum(expr!(lc!(((i, range!(11..32))) => mul!(access!(sign_bit, Type::Felt), exp!(int!(2), access!(i, Type::Felt)))))))) + => let_!(immediate_bits = expr!(slice!(instruction_bits, 20..31, Type::Vector(11))) + => let_!(low_bit_sum = expr!(call!(sum(expr!(lc!(((i, range!(0..11)), (immediate_bit, expr!(slice!(instruction_bits, 20..31, Type::Vector(11))))) => mul!(access!(immediate_bit, Type::Felt), exp!(int!(2), access!(i, Type::Felt)))))))) + => enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt)))), + enforce!(eq!(access!(sign_bit, Type::Felt), int!(1))))))), + ]; + expected.evaluators.insert( + function_ident!(root, imm_reconstruction), + EvaluatorFunction::new( + SourceSpan::UNKNOWN, + ident!(imm_reconstruction), + vec![trace_segment!( + 0, + "%2", + [(instruction_bits, 32), (immediate, 1)] + )], + body, + ), + ); + + assert_eq!(program, expected); +} diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 25e0f4ae..b8a34383 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -448,7 +448,11 @@ impl<'a> Inlining<'a> { Ok(vec![Statement::Expr(folded)]) } Expr::ListComprehension(lc) => { + // Expand the comprehension, but ensure we don't treat it like a comprehension constraint + let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false); let mut expanded = self.expand_comprehension(lc)?; + self.in_comprehension_constraint = in_cc; + // Apply the fold to the expanded comprehension in the bottom of the let tree with_let_result(self, &mut expanded, |inliner, value| { match value { // The result value of expanding a comprehension _must_ be a vector