Skip to content

Commit

Permalink
implement field idiv and rem ops
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed Sep 26, 2023
1 parent c7e4e29 commit afdb931
Show file tree
Hide file tree
Showing 33 changed files with 602 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ members = [
]

[profile.dev]
opt-level = 1
# opt-level = 1
2 changes: 1 addition & 1 deletion zokrates_analysis/src/expression_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator {
| FieldElementExpression::Xor(_)
| FieldElementExpression::LeftShift(_)
| FieldElementExpression::RightShift(_) => Err(Error(format!(
"Found non-constant bitwise operation in field element expression `{}`",
"Field element expression `{}` must be a constant expression",
e
))),
FieldElementExpression::Pow(e) => {
Expand Down
6 changes: 6 additions & 0 deletions zokrates_analysis/src/flatten_complex_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,12 @@ fn fold_field_expression<'ast, T: Field>(
typed::FieldElementExpression::Div(e) => {
zir::FieldElementExpression::Div(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::IDiv(e) => {
zir::FieldElementExpression::IDiv(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Rem(e) => {
zir::FieldElementExpression::Rem(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Pow(e) => {
zir::FieldElementExpression::Pow(f.fold_binary_expression(statements_buffer, e))
}
Expand Down
19 changes: 19 additions & 0 deletions zokrates_analysis/src/panic_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,25 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
);
FieldElementExpression::div(n, d)
}
FieldElementExpression::IDiv(e) => {
let n = self.fold_field_expression(*e.left);
let d = self.fold_field_expression(*e.right);
self.panic_buffer.push(
ZirStatement::assertion(
BooleanExpression::not(
BooleanExpression::field_eq(
d.clone().span(span),
FieldElementExpression::value(T::zero()).span(span),
)
.span(span),
)
.span(span),
RuntimeError::DivisionByZero,
)
.span(span),
);
FieldElementExpression::idiv(n, d)
}
e => fold_field_expression_cases(self, e),
}
}
Expand Down
97 changes: 91 additions & 6 deletions zokrates_analysis/src/propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ impl fmt::Display for Error {
Error::Type(s) => write!(f, "{}", s),
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
Error::InvalidValue(s) => write!(f, "{}", s),
Error::OutOfBounds(index, size) => write!(
f,
"Out of bounds index ({} >= {}) found during static analysis",
index, size
),
Error::OutOfBounds(index, size) => {
write!(f, "Out of bounds index ({} >= {})", index, size)
}
Error::VariableLength(message) => write!(f, "{}", message),
Error::DivisionByZero => {
write!(f, "Division by zero detected during static analysis",)
write!(f, "Division by zero detected",)
}
}
}
Expand Down Expand Up @@ -856,6 +854,22 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
Ok(UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Or(e) => match (
self.fold_uint_expression(*e.left)?.into_inner(),
self.fold_uint_expression(*e.right)?.into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(UExpression::value(v1.value | v2.value))
}
(UExpressionInner::Value(v), e) | (e, UExpressionInner::Value(v))
if v.value == 0 =>
{
Ok(e)
}
(e1, e2) => {
Ok(UExpression::or(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Not(e) => {
let e = self.fold_uint_expression(*e.inner)?.into_inner();
match e {
Expand Down Expand Up @@ -939,6 +953,35 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
(e1, e2) => Ok(e1 / e2),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => FieldElementExpression::idiv(e1, e2),
})
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
FieldElementExpression::value(T::zero())
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => e1 % e2,
})
}
FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? {
FieldElementExpression::Value(n) => {
Ok(FieldElementExpression::value(T::zero() - n.value))
Expand Down Expand Up @@ -1606,6 +1649,48 @@ mod tests {
);
}

#[test]
fn idiv() {
let e = FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
);

assert_eq!(
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);
}

#[test]
fn rem() {
let mut propagator = Propagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);
}

#[test]
fn pow() {
let e = FieldElementExpression::pow(
Expand Down
112 changes: 106 additions & 6 deletions zokrates_analysis/src/zir_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ pub enum Error {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::OutOfBounds(index, size) => write!(
f,
"Out of bounds index ({} >= {}) found in zir during static analysis",
index, size
),
Error::OutOfBounds(index, size) => {
write!(f, "Out of bounds index ({} >= {})", index, size)
}
Error::DivisionByZero => {
write!(f, "Division by zero detected in zir during static analysis",)
write!(f, "Division by zero detected",)
}
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
}
Expand Down Expand Up @@ -343,6 +341,42 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(FieldElementExpression::div(e1, e2).span(e.span)),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e),
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::idiv(e1, e2).span(e.span)),
}
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
Ok(FieldElementExpression::value(T::zero()))
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::rem(e1, e2).span(e.span)),
}
}
FieldElementExpression::Pow(e) => {
let exponent = self.fold_uint_expression(*e.right)?;
match (self.fold_field_expression(*e.left)?, exponent.into_inner()) {
Expand Down Expand Up @@ -1099,6 +1133,72 @@ mod tests {
);
}

#[test]
fn idiv() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::identifier("a".into()))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn rem() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::div(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn pow() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
Expand Down
7 changes: 7 additions & 0 deletions zokrates_ast/src/common/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ impl OperatorStr for OpDiv {
const STR: &'static str = "/";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpIDiv;

impl OperatorStr for OpIDiv {
const STR: &'static str = "\\";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpRem;

Expand Down
6 changes: 5 additions & 1 deletion zokrates_ast/src/ir/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ir::Parameter;
use crate::ir::ProgIterator;
use crate::ir::Statement;
use crate::ir::Variable;
use crate::Solver;
use std::collections::HashSet;
use zokrates_field::Field;

Expand Down Expand Up @@ -46,7 +47,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector {
&mut self,
d: DirectiveStatement<'ast, T>,
) -> Vec<Statement<'ast, T>> {
self.variables.extend(d.outputs.iter());
match d.solver {
Solver::Zir(_) => {} // we do not check variables introduced by assembly
_ => self.variables.extend(d.outputs.iter()), // this is not necessary, but we keep it as a sanity check
};
vec![Statement::Directive(d)]
}
}
4 changes: 2 additions & 2 deletions zokrates_ast/src/ir/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
if matches!(s, Statement::Constraint(..)) {
count += 1;
}
let s: Vec<Statement<T>> = solver_indexer
let s: Vec<Statement<T>> = unconstrained_variable_detector
.fold_statement(s)
.into_iter()
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
.flat_map(|s| solver_indexer.fold_statement(s))
.collect();
for s in s {
serde_cbor::to_writer(&mut w, &s)?;
Expand Down
8 changes: 8 additions & 0 deletions zokrates_ast/src/typed/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,14 @@ pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>(
BinaryOrExpression::Binary(e) => Div(e),
BinaryOrExpression::Expression(u) => u,
},
IDiv(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => IDiv(e),
BinaryOrExpression::Expression(u) => u,
},
Rem(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => Rem(e),
BinaryOrExpression::Expression(u) => u,
},
Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => Pow(e),
BinaryOrExpression::Expression(u) => u,
Expand Down
Loading

0 comments on commit afdb931

Please sign in to comment.