diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 18335932e..de86a66a7 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -2020,49 +2020,31 @@ impl<'ast, T: Field> Flattener<'ast, T> { let from = std::cmp::max(from, to); let res = match self.bits_cache.entry(e.field.clone().unwrap()) { - Entry::Occupied(mut entry) => { - let res: Vec<_> = entry.get().clone(); + Entry::Occupied(entry) => { + let mut res: Vec<_> = entry.get().clone(); - if res.len() > to { - // if the result is bigger than `to`, we zero check the sum of higher bits up to `to` - let bit_sum = res[..res.len() - to] - .iter() - .cloned() - .fold(FlatExpression::from(T::zero()), |acc, e| { - FlatExpression::add(acc, e) - }); - - // sum check - statements_flattened.push_back(FlatStatement::condition( - FlatExpression::value(T::from(0)), - bit_sum, - error, - )); - - // truncate to the `to` lowest bits - let bits = res[res.len() - to..].to_vec(); - assert_eq!(bits.len(), to); - - // update the entry - entry.insert( - (0..res.len() - to) - .map(|_| FlatExpression::value(T::zero())) - .chain(bits.clone()) - .collect(), - ); + // only keep the last `to` values and return the sum of the others + let sum = res + .drain(0..res.len().saturating_sub(to)) + .fold(FlatExpression::from(T::zero()), |acc, e| { + FlatExpression::add(acc, e) + }); - return bits; - } + // force the sum to be 0 + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(0)), + sum, + error, + )); - // if result is smaller than `to` we pad it with zeroes on the left (big endian) to return `to` bits - if res.len() < to { - return (0..to - res.len()) - .map(|_| FlatExpression::value(T::zero())) - .chain(res) - .collect(); - } + // sanity check that we have at most `to` values + assert!(res.len() <= to); - res + // return the result left-padded to `to` values + std::iter::repeat(FlatExpression::value(T::zero())) + .take(to - res.len()) + .chain(res.clone()) + .collect() } Entry::Vacant(_) => { let bits = (0..from).map(|_| self.use_sym()).collect::>(); @@ -2700,7 +2682,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { c.value, error.into(), ), - // c < e <=> 2^bw - 1 - e < 2^bw - 1 - c + // c <= e <=> 2^bw - 1 - e <= 2^bw - 1 - c (FlatExpression::Value(c), e) => { let max = T::from(2u32).pow(bitwidth) - T::one(); self.enforce_constant_le_check(