Skip to content

Commit

Permalink
suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed Aug 1, 2023
1 parent 52b6815 commit 8c88527
Showing 1 changed file with 22 additions and 40 deletions.
62 changes: 22 additions & 40 deletions zokrates_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2020,49 +2020,31 @@ impl<'ast, T: Field> Flattener<'ast, T> {

let from = std::cmp::max(from, to);
let res = match self.bits_cache.entry(e.field.clone().unwrap()) {
Entry::Occupied(mut entry) => {
let res: Vec<_> = entry.get().clone();
Entry::Occupied(entry) => {
let mut res: Vec<_> = entry.get().clone();

if res.len() > to {
// if the result is bigger than `to`, we zero check the sum of higher bits up to `to`
let bit_sum = res[..res.len() - to]
.iter()
.cloned()
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

// sum check
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
bit_sum,
error,
));

// truncate to the `to` lowest bits
let bits = res[res.len() - to..].to_vec();
assert_eq!(bits.len(), to);

// update the entry
entry.insert(
(0..res.len() - to)
.map(|_| FlatExpression::value(T::zero()))
.chain(bits.clone())
.collect(),
);
// only keep the last `to` values and return the sum of the others
let sum = res
.drain(0..res.len().saturating_sub(to))
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

return bits;
}
// force the sum to be 0
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
sum,
error,
));

// if result is smaller than `to` we pad it with zeroes on the left (big endian) to return `to` bits
if res.len() < to {
return (0..to - res.len())
.map(|_| FlatExpression::value(T::zero()))
.chain(res)
.collect();
}
// sanity check that we have at most `to` values
assert!(res.len() <= to);

res
// return the result left-padded to `to` values
std::iter::repeat(FlatExpression::value(T::zero()))
.take(to - res.len())
.chain(res.clone())
.collect()
}
Entry::Vacant(_) => {
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
Expand Down Expand Up @@ -2700,7 +2682,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
c.value,
error.into(),
),
// c < e <=> 2^bw - 1 - e < 2^bw - 1 - c
// c <= e <=> 2^bw - 1 - e <= 2^bw - 1 - c
(FlatExpression::Value(c), e) => {
let max = T::from(2u32).pow(bitwidth) - T::one();
self.enforce_constant_le_check(
Expand Down

0 comments on commit 8c88527

Please sign in to comment.