Skip to content

Commit

Permalink
fix: incorrect propagation of constraint flag during inlining
Browse files Browse the repository at this point in the history
This fixes an issue which caused expansion of a folded comprehension
within the context of a constraint to be improperly treated as a
constraint comprehension. This expansion occurs during the inlining
phase.

This bug occurred because a flag determines whether a given
comprehension is a constraint comprehension or a regular list
comprehension, and is set to true when expanding a constraint.
However, while expanding a constraint, we may encounter a call to a
function, such as one of the list folding builtins, in which case
a comprehension passed as an argument to those builtins is not
expressing a constraint, so the flag should be false while expanding
it.

We already manage this flag correctly in other contexts, but this one
was missed during development.

Fixes #340
  • Loading branch information
bitwalker committed Jul 23, 2023
1 parent 4d83841 commit 451119b
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 0 deletions.
253 changes: 253 additions & 0 deletions parser/src/parser/tests/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1113,3 +1113,256 @@ fn test_inlining_constraints_with_folded_comprehensions_in_evaluator() {

assert_eq!(program, expected);
}

/// This test originally reproduced the bug in air-script#340, but as of this commit
/// that bug is fixed. This test remains not to prevent regressions necessarily, but
/// to add a more realistic test case to our test suite, and potentially catch bugs
/// we accidentally introduce in the future which blow up on this code.
#[test]
fn test_repro_issue340() {
// NOTE: This code is exactly what was written in #340, but the only significant
// part is the fact that a panic was raised due to visiting one of the folded
// compehensions in `imm_reconstruction` as if it was a constraint comprehension,
// due to incorrectly propagating the constraint flag in that situation. The
// bug no longer exists, but this test program may be useful for other tests of
// a more real-world nature, so it is preserved here
let root = r#"
def root
trace_columns:
main: [instruction_word, instruction_bits[32], immediate, s]
public_inputs:
stack_inputs: [16]
stack_outputs: [16]
periodic_columns:
k0: [1, 1, 1, 1, 1, 1, 1, 0]
boundary_constraints:
# define boundary constraints against the main trace at the first row of the trace.
enf instruction_word.first = 0
enf instruction_word.last = 2
integrity_constraints:
# The instruction bit decomposition must be bits
enf b^2 = b for b in instruction_bits
# Ensure they add up to the instruction word:
let word_sum = sum([2^i * a for (i, a) in (0..32, instruction_bits)])
enf instruction_word = word_sum
enf match:
case s: imm_reconstruction([instruction_bits, immediate])
case !s: immediate = 0
# The highest bit is a sign bit, so we sign extend then reconstruct from the other bits
ev imm_reconstruction([instruction_bits[32], immediate]):
let sign_bit = instruction_bits[31]
let high_bit_sum = sum([sign_bit*2^i for i in (11..32)])
let immediate_bits = instruction_bits[20..31]
let low_bit_sum = sum([immediate_bit * 2^i for (i, immediate_bit) in (0..11, instruction_bits[20..31])])
enf immediate = low_bit_sum + high_bit_sum
enf sign_bit = 1
"#;

let test = ParseTest::new();
let program = match test.parse_program(root) {
Err(err) => {
test.diagnostics.emit(err);
panic!("expected parsing to succeed, see diagnostics for details");
}
Ok(ast) => ast,
};

let mut pipeline =
ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics));
let program = pipeline.run(program).unwrap();

let mut expected = Program::new(ident!(root));
expected.trace_columns.push(trace_segment!(
0,
"$main",
[
(instruction_word, 1),
(instruction_bits, 32),
(immediate, 1),
(s, 1)
]
));
expected.public_inputs.insert(
ident!(stack_inputs),
PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_inputs), 16),
);
expected.public_inputs.insert(
ident!(stack_outputs),
PublicInput::new(SourceSpan::UNKNOWN, ident!(stack_outputs), 16),
);
expected.boundary_constraints.push(enforce!(eq!(
bounded_access!(instruction_word, Boundary::First, Type::Felt),
int!(0)
)));
expected.boundary_constraints.push(enforce!(eq!(
bounded_access!(instruction_word, Boundary::Last, Type::Felt),
int!(2)
)));
for i in 0..32 {
let access = ScalarExpr::SymbolAccess(SymbolAccess {
span: miden_diagnostics::SourceSpan::UNKNOWN,
name: ResolvableIdentifier::Local(ident!(instruction_bits)),
access_type: AccessType::Index(i),
offset: 0,
ty: Some(Type::Felt),
});
expected
.integrity_constraints
.push(enforce!(eq!(exp!(access.clone(), int!(2)), access.clone())));
}
let word_sum = (1..32)
.into_iter()
.fold(access!("%lc0", Type::Felt), |acc, i| {
let access = ScalarExpr::SymbolAccess(SymbolAccess {
span: miden_diagnostics::SourceSpan::UNKNOWN,
name: ResolvableIdentifier::Local(Identifier::new(
miden_diagnostics::SourceSpan::UNKNOWN,
crate::Symbol::intern(format!("%lc{}", i)),
)),
access_type: AccessType::Default,
offset: 0,
ty: Some(Type::Felt),
});
add!(acc, access)
});
let high_bit_sum = (33..53)
.into_iter()
.fold(access!("%lc32", Type::Felt), |acc, i| {
let access = ScalarExpr::SymbolAccess(SymbolAccess {
span: miden_diagnostics::SourceSpan::UNKNOWN,
name: ResolvableIdentifier::Local(Identifier::new(
miden_diagnostics::SourceSpan::UNKNOWN,
crate::Symbol::intern(format!("%lc{}", i)),
)),
access_type: AccessType::Default,
offset: 0,
ty: Some(Type::Felt),
});
add!(acc, access)
});
let low_bit_sum = (54..64)
.into_iter()
.fold(access!("%lc53", Type::Felt), |acc, i| {
let access = ScalarExpr::SymbolAccess(SymbolAccess {
span: miden_diagnostics::SourceSpan::UNKNOWN,
name: ResolvableIdentifier::Local(Identifier::new(
miden_diagnostics::SourceSpan::UNKNOWN,
crate::Symbol::intern(format!("%lc{}", i)),
)),
access_type: AccessType::Default,
offset: 0,
ty: Some(Type::Felt),
});
add!(acc, access)
});
let low_bit_sum_body = let_!(low_bit_sum = expr!(low_bit_sum) =>
enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt))), when access!(s, Type::Felt)),
enforce!(eq!(access!(instruction_bits[31], Type::Felt), int!(1)), when access!(s, Type::Felt)));
let high_bit_sum_body = let_!(high_bit_sum = expr!(high_bit_sum)
=> let_!("%lc53" = expr!(mul!(access!(instruction_bits[20], Type::Felt), int!(1)))
=> let_!("%lc54" = expr!(mul!(access!(instruction_bits[21], Type::Felt), int!(2)))
=> let_!("%lc55" = expr!(mul!(access!(instruction_bits[22], Type::Felt), int!(4)))
=> let_!("%lc56" = expr!(mul!(access!(instruction_bits[23], Type::Felt), int!(8)))
=> let_!("%lc57" = expr!(mul!(access!(instruction_bits[24], Type::Felt), int!(16)))
=> let_!("%lc58" = expr!(mul!(access!(instruction_bits[25], Type::Felt), int!(32)))
=> let_!("%lc59" = expr!(mul!(access!(instruction_bits[26], Type::Felt), int!(64)))
=> let_!("%lc60" = expr!(mul!(access!(instruction_bits[27], Type::Felt), int!(128)))
=> let_!("%lc61" = expr!(mul!(access!(instruction_bits[28], Type::Felt), int!(256)))
=> let_!("%lc62" = expr!(mul!(access!(instruction_bits[29], Type::Felt), int!(512)))
=> let_!("%lc63" = expr!(mul!(access!(instruction_bits[30], Type::Felt), int!(1024)))
=> low_bit_sum_body))))))))))));
let word_sum_body = let_!(word_sum = expr!(word_sum)
=> enforce!(eq!(access!(instruction_word, Type::Felt), access!(word_sum, Type::Felt))),
let_!("%lc32" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2048)))
=> let_!("%lc33" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4096)))
=> let_!("%lc34" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8192)))
=> let_!("%lc35" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16384)))
=> let_!("%lc36" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(32768)))
=> let_!("%lc37" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(65536)))
=> let_!("%lc38" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(131072)))
=> let_!("%lc39" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(262144)))
=> let_!("%lc40" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(524288)))
=> let_!("%lc41" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1048576)))
=> let_!("%lc42" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2097152)))
=> let_!("%lc43" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(4194304)))
=> let_!("%lc44" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(8388608)))
=> let_!("%lc45" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(16777216)))
=> let_!("%lc46" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(33554432)))
=> let_!("%lc47" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(67108864)))
=> let_!("%lc48" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(134217728)))
=> let_!("%lc49" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(268435456)))
=> let_!("%lc50" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(536870912)))
=> let_!("%lc51" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(1073741824)))
=> let_!("%lc52" = expr!(mul!(access!(instruction_bits[31], Type::Felt), int!(2147483648)))
=> high_bit_sum_body))))))))))))))))))))),
enforce!(eq!(access!(immediate, Type::Felt), int!(0)), when not!(access!(s, Type::Felt)))
);

expected.integrity_constraints.push(
let_!("%lc0" = expr!(mul!(int!(1), access!(instruction_bits[0], Type::Felt)))
=> let_!("%lc1" = expr!(mul!(int!(2), access!(instruction_bits[1], Type::Felt)))
=> let_!("%lc2" = expr!(mul!(int!(4), access!(instruction_bits[2], Type::Felt)))
=> let_!("%lc3" = expr!(mul!(int!(8), access!(instruction_bits[3], Type::Felt)))
=> let_!("%lc4" = expr!(mul!(int!(16), access!(instruction_bits[4], Type::Felt)))
=> let_!("%lc5" = expr!(mul!(int!(32), access!(instruction_bits[5], Type::Felt)))
=> let_!("%lc6" = expr!(mul!(int!(64), access!(instruction_bits[6], Type::Felt)))
=> let_!("%lc7" = expr!(mul!(int!(128), access!(instruction_bits[7], Type::Felt)))
=> let_!("%lc8" = expr!(mul!(int!(256), access!(instruction_bits[8], Type::Felt)))
=> let_!("%lc9" = expr!(mul!(int!(512), access!(instruction_bits[9], Type::Felt)))
=> let_!("%lc10" = expr!(mul!(int!(1024), access!(instruction_bits[10], Type::Felt)))
=> let_!("%lc11" = expr!(mul!(int!(2048), access!(instruction_bits[11], Type::Felt)))
=> let_!("%lc12" = expr!(mul!(int!(4096), access!(instruction_bits[12], Type::Felt)))
=> let_!("%lc13" = expr!(mul!(int!(8192), access!(instruction_bits[13], Type::Felt)))
=> let_!("%lc14" = expr!(mul!(int!(16384), access!(instruction_bits[14], Type::Felt)))
=> let_!("%lc15" = expr!(mul!(int!(32768), access!(instruction_bits[15], Type::Felt)))
=> let_!("%lc16" = expr!(mul!(int!(65536), access!(instruction_bits[16], Type::Felt)))
=> let_!("%lc17" = expr!(mul!(int!(131072), access!(instruction_bits[17], Type::Felt)))
=> let_!("%lc18" = expr!(mul!(int!(262144), access!(instruction_bits[18], Type::Felt)))
=> let_!("%lc19" = expr!(mul!(int!(524288), access!(instruction_bits[19], Type::Felt)))
=> let_!("%lc20" = expr!(mul!(int!(1048576), access!(instruction_bits[20], Type::Felt)))
=> let_!("%lc21" = expr!(mul!(int!(2097152), access!(instruction_bits[21], Type::Felt)))
=> let_!("%lc22" = expr!(mul!(int!(4194304), access!(instruction_bits[22], Type::Felt)))
=> let_!("%lc23" = expr!(mul!(int!(8388608), access!(instruction_bits[23], Type::Felt)))
=> let_!("%lc24" = expr!(mul!(int!(16777216), access!(instruction_bits[24], Type::Felt)))
=> let_!("%lc25" = expr!(mul!(int!(33554432), access!(instruction_bits[25], Type::Felt)))
=> let_!("%lc26" = expr!(mul!(int!(67108864), access!(instruction_bits[26], Type::Felt)))
=> let_!("%lc27" = expr!(mul!(int!(134217728), access!(instruction_bits[27], Type::Felt)))
=> let_!("%lc28" = expr!(mul!(int!(268435456), access!(instruction_bits[28], Type::Felt)))
=> let_!("%lc29" = expr!(mul!(int!(536870912), access!(instruction_bits[29], Type::Felt)))
=> let_!("%lc30" = expr!(mul!(int!(1073741824), access!(instruction_bits[30], Type::Felt)))
=> let_!("%lc31" = expr!(mul!(int!(2147483648), access!(instruction_bits[31], Type::Felt)))
=> word_sum_body)))))))))))))))))))))))))))))))),
);
// The evaluator definition is never modified by constant propagation or inlining
let body = vec![
let_!(sign_bit = expr!(access!(instruction_bits[31], Type::Felt))
=> let_!(high_bit_sum = expr!(call!(sum(expr!(lc!(((i, range!(11..32))) => mul!(access!(sign_bit, Type::Felt), exp!(int!(2), access!(i, Type::Felt))))))))
=> let_!(immediate_bits = expr!(slice!(instruction_bits, 20..31, Type::Vector(11)))
=> let_!(low_bit_sum = expr!(call!(sum(expr!(lc!(((i, range!(0..11)), (immediate_bit, expr!(slice!(instruction_bits, 20..31, Type::Vector(11))))) => mul!(access!(immediate_bit, Type::Felt), exp!(int!(2), access!(i, Type::Felt))))))))
=> enforce!(eq!(access!(immediate, Type::Felt), add!(access!(low_bit_sum, Type::Felt), access!(high_bit_sum, Type::Felt)))),
enforce!(eq!(access!(sign_bit, Type::Felt), int!(1))))))),
];
expected.evaluators.insert(
function_ident!(root, imm_reconstruction),
EvaluatorFunction::new(
SourceSpan::UNKNOWN,
ident!(imm_reconstruction),
vec![trace_segment!(
0,
"%2",
[(instruction_bits, 32), (immediate, 1)]
)],
body,
),
);

assert_eq!(program, expected);
}
4 changes: 4 additions & 0 deletions parser/src/transforms/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,11 @@ impl<'a> Inlining<'a> {
Ok(vec![Statement::Expr(folded)])
}
Expr::ListComprehension(lc) => {
// Expand the comprehension, but ensure we don't treat it like a comprehension constraint
let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false);
let mut expanded = self.expand_comprehension(lc)?;
self.in_comprehension_constraint = in_cc;
// Apply the fold to the expanded comprehension in the bottom of the let tree
with_let_result(self, &mut expanded, |inliner, value| {
match value {
// The result value of expanding a comprehension _must_ be a vector
Expand Down

0 comments on commit 451119b

Please sign in to comment.