Skip to content

Commit

Permalink
update bit cache with correct bits
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed May 22, 2023
1 parent 73e6b09 commit 52b6815
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
1 change: 1 addition & 0 deletions changelogs/unreleased/1309-dark64
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix uint range check in assertions when comparing to a constant
23 changes: 16 additions & 7 deletions zokrates_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1985,7 +1985,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
///
/// # Notes
/// * `from` and `to` must be smaller or equal to `T::get_required_bits()`, the bitwidth of the prime field
/// * the result is not checked to be in range. This is fine for `to < T::get_required_bits()`, but otherwise it is the caller's responsibility to add that check
/// * The result is not checked to be in range unless the bits of the expression were already decomposed with a higher bitwidth than `to`
/// * This is fine for `to < T::get_required_bits()`, but otherwise it is the caller's responsibility to add that check
fn get_bits_unchecked(
&mut self,
e: &FlatUExpression<T>,
Expand Down Expand Up @@ -2019,30 +2020,38 @@ impl<'ast, T: Field> Flattener<'ast, T> {

let from = std::cmp::max(from, to);
let res = match self.bits_cache.entry(e.field.clone().unwrap()) {
Entry::Occupied(entry) => {
Entry::Occupied(mut entry) => {
let res: Vec<_> = entry.get().clone();

if res.len() > to {
// if the result is bigger than `to`, we sum higher bits up to `to`
// if the result is bigger than `to`, we zero check the sum of higher bits up to `to`
let bit_sum = res[..res.len() - to]
.iter()
.cloned()
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

// sum of higher bits must be zero
// sum check
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
bit_sum,
error,
));

// truncate to the `to` lowest bits
let res = res[res.len() - to..].to_vec();
assert_eq!(res.len(), to);
let bits = res[res.len() - to..].to_vec();
assert_eq!(bits.len(), to);

// update the entry
entry.insert(
(0..res.len() - to)
.map(|_| FlatExpression::value(T::zero()))
.chain(bits.clone())
.collect(),
);

return res;
return bits;
}

// if result is smaller than `to` we pad it with zeroes on the left (big endian) to return `to` bits
Expand Down

0 comments on commit 52b6815

Please sign in to comment.