Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce cost of branching #1242

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelogs/unreleased/1242-schaeff
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce the runtime cost of conditionals
3 changes: 3 additions & 0 deletions zokrates_ast/src/common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum RuntimeError {
LtSymetric,
Or,
Xor,
Ite,
IncompleteDynamicRange,
Inverse,
Euclidean,
Expand Down Expand Up @@ -54,6 +55,7 @@ impl RuntimeError {
| SelectRangeCheck
| ArgumentBitness
| IncompleteDynamicRange
| Ite
)
}
}
Expand All @@ -80,6 +82,7 @@ impl fmt::Display for RuntimeError {
LtSymetric => "Symetrical check failed in Lt check",
Or => "Or check failed",
Xor => "Xor check failed",
Ite => "Conditional check failed",
IncompleteDynamicRange => {
"Failed to compare field elements because dynamic comparison is incomplete"
}
Expand Down
2 changes: 2 additions & 0 deletions zokrates_ast/src/common/solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;

#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver {
Ite,
ConditionEq,
Bits(usize),
Div,
Expand All @@ -26,6 +27,7 @@ impl fmt::Display for Solver {
impl Solver {
pub fn get_signature(&self) -> (usize, usize) {
match self {
Solver::Ite => (3, 1),
Solver::ConditionEq => (1, 2),
Solver::Bits(bit_width) => (1, *bit_width),
Solver::Div => (2, 1),
Expand Down
2 changes: 1 addition & 1 deletion zokrates_ast/src/ir/smtlib2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<T: Field> SMTLib2 for Prog<T> {
"; Number of circuit variables: {}",
collector.variables.len()
)?;
writeln!(f, "; Number of equalities: {}", self.statements.len())?;
writeln!(f, "; Number of equalities: {}", self.constraint_count())?;

writeln!(f, "(declare-const |~prime| Int)")?;
for v in collector.variables.iter() {
Expand Down
2 changes: 1 addition & 1 deletion zokrates_cli/tests/code/conditional_false.smt2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; Auto generated by ZoKrates
; Number of circuit variables: 5
; Number of equalities: 4
; Number of equalities: 3
(declare-const |~prime| Int)
(declare-const |~out_0| Int)
(declare-const |~one| Int)
Expand Down
2 changes: 1 addition & 1 deletion zokrates_cli/tests/code/conditional_true.smt2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; Auto generated by ZoKrates
; Number of circuit variables: 5
; Number of equalities: 4
; Number of equalities: 3
(declare-const |~prime| Int)
(declare-const |~out_0| Int)
(declare-const |~one| Int)
Expand Down
9 changes: 5 additions & 4 deletions zokrates_cli/tests/code/taxation.smt2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; Auto generated by ZoKrates
; Number of circuit variables: 260
; Number of equalities: 261
; Number of equalities: 259
(declare-const |~prime| Int)
(declare-const |~out_0| Int)
(declare-const |~one| Int)
Expand Down Expand Up @@ -261,7 +261,7 @@
(declare-const |_254| Int)
(declare-const |_257| Int)
(declare-const |_258| Int)
(declare-const |_264| Int)
(declare-const |_263| Int)
(assert (and
(= |~prime| 21888242871839275222246405745257275088548364400416034343698204186575808495617)
(= |~one| 1)
Expand Down Expand Up @@ -524,6 +524,7 @@
(= (mod (* (+ (* |~one| 14651237294507013008273219182214280847718990358813499091232105186081237893121) (* |_0| 1) (* |_1| 21888242871839275222246405745257275088548364400416034343698204186575808495616)) (* |_258| 1)) |~prime|) (mod (* |_257| 1) |~prime|))
(= (mod (* (+ (* |~one| 1) (* |_257| 21888242871839275222246405745257275088548364400416034343698204186575808495616)) (+ (* |~one| 14651237294507013008273219182214280847718990358813499091232105186081237893121) (* |_0| 1) (* |_1| 21888242871839275222246405745257275088548364400416034343698204186575808495616))) |~prime|) (mod 0 |~prime|))
(= (mod (* (* |~one| 1) (* |~one| 1)) |~prime|) (mod (* |_257| 1) |~prime|))
(= (mod (* (* |_2| 1) (+ (* |_0| 21888242871839275222246405745257275088548364400416034343698204186575808495616) (* |_1| 1))) |~prime|) (mod (* |_264| 1) |~prime|))
(= (mod (* (* |~one| 1) (* |_264| 1)) |~prime|) (mod (* |~out_0| 1) |~prime|))

(= (mod (* (+ (* |~one| 1) (* |_2| 21888242871839275222246405745257275088548364400416034343698204186575808495616)) (+ (* |_0| 1) (* |_1| 21888242871839275222246405745257275088548364400416034343698204186575808495616))) |~prime|) (mod (+ (* |_0| 1) (* |_1| 21888242871839275222246405745257275088548364400416034343698204186575808495616) (* |_263| 1)) |~prime|))
(= (mod (* (* |~one| 1) (* |_263| 1)) |~prime|) (mod (* |~out_0| 1) |~prime|))
))
39 changes: 13 additions & 26 deletions zokrates_core/src/flatten/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,38 +634,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let alternative_id = self.use_sym();
statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative));

let term0_id = self.use_sym();
statements_flattened.push_back(FlatStatement::Definition(
term0_id,
FlatExpression::Mult(
box condition_id.into(),
box FlatExpression::from(consequence_id),
),
));
let res_id = self.use_sym();

let term1_id = self.use_sym();
statements_flattened.push_back(FlatStatement::Definition(
term1_id,
FlatExpression::Mult(
box FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box condition_id.into(),
),
box FlatExpression::from(alternative_id),
),
));
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![res_id],
Solver::Ite,
vec![condition_id, consequence_id, alternative_id],
)));

let res = self.use_sym();
statements_flattened.push_back(FlatStatement::Definition(
res,
FlatExpression::Add(
box FlatExpression::from(term0_id),
box FlatExpression::from(term1_id),
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Sub(box res_id.into(), box alternative_id.into()),
FlatExpression::Mult(
box condition_id.into(),
box FlatExpression::Sub(box consequence_id.into(), box alternative_id.into()),
),
RuntimeError::Ite,
));

FlatUExpression {
field: Some(FlatExpression::Identifier(res)),
field: Some(res_id.into()),
bits: None,
}
}
Expand Down
48 changes: 10 additions & 38 deletions zokrates_core/src/optimizer/redefinition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,52 +120,24 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
Statement::Directive(d) => {
let d = self.fold_directive(d);

// check if the inputs are constants, ie reduce to the form `coeff * ~one`
let inputs: Vec<_> = d
.inputs
.into_iter()
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
.map(|i| i.reduce())
.map(|q| {
match q
.try_linear()
.map(|l| l.try_constant().map_err(|l| l.into()))
{
Ok(r) => r,
Err(e) => Err(e),
}
})
.collect::<Vec<Result<T, QuadComb<T>>>>();

match inputs.iter().all(|i| i.is_ok()) {
true => {
// unwrap inputs to their constant value
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap();
assert_eq!(outputs.len(), d.outputs.len());

// insert the results in the substitution
// we run symbolic execution so that we can propagate cases like `c ? 1 : 0 === c`
match Interpreter::execute_solver_symbolic(d.solver, d.inputs) {
Ok(outputs) => {
for (output, value) in d.outputs.into_iter().zip(outputs.into_iter()) {
self.substitution
.insert(output, LinComb::from(value).into_canonical());
self.substitution.insert(output, value.into_canonical());
}
vec![]
}
false => {
//reconstruct the input expressions
let inputs: Vec<_> = inputs
.into_iter()
.map(|i| {
i.map(|v| LinComb::summand(v, Variable::one()).into())
.unwrap_or_else(|q| q)
})
.collect();
Err((solver, inputs)) => {
// to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set
for o in d.outputs.iter().cloned() {
self.ignore.insert(o);
}
vec![Statement::Directive(Directive { inputs, ..d })]
vec![Statement::Directive(Directive {
inputs,
solver,
..d
})]
}
}
}
Expand Down
1 change: 1 addition & 0 deletions zokrates_core_test/tests/tests/array_conditional.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"entry_point": "./tests/tests/array_conditional.zok",
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"max_constraint_count": 6,
"tests": [
{
"input": {
Expand Down
2 changes: 1 addition & 1 deletion zokrates_core_test/tests/tests/uint/conditional.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"entry_point": "./tests/tests/uint/conditional.zok",
"max_constraint_count": 31,
"max_constraint_count": 30,
"tests": [
{
"input": {
Expand Down
65 changes: 65 additions & 0 deletions zokrates_interpreter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use zokrates_ast::ir::{
use zokrates_field::Field;

pub type ExecutionResult<T> = Result<Witness<T>, Error>;
type SolverCall<T> = (Solver, Vec<QuadComb<T>>);

#[derive(Default)]
pub struct Interpreter {
Expand Down Expand Up @@ -157,11 +158,75 @@ impl Interpreter {
}
}

pub fn execute_solver_symbolic<T: Field>(
solver: Solver,
inputs: Vec<QuadComb<T>>,
) -> Result<Vec<LinComb<T>>, SolverCall<T>> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert_eq!(inputs.len(), expected_input_count);

// check if the inputs are constants, ie reduce to the form `coeff * ~one`
let constant_inputs: Vec<_> = inputs
.clone()
.into_iter()
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
.map(|i| i.reduce())
.map(|q| {
match q
.try_linear()
.map(|l| l.try_constant().map_err(|l| l.into()))
{
Ok(r) => r,
Err(e) => Err(e),
}
})
.collect::<Vec<Result<T, QuadComb<T>>>>();

match constant_inputs.iter().all(|i| i.is_ok()) {
// run concrete execution
true => {
// unwrap inputs to their constant value
let constant_inputs: Vec<_> =
constant_inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = Interpreter::execute_solver(&solver, &constant_inputs).unwrap();
assert_eq!(outputs.len(), expected_output_count);
Ok(outputs.into_iter().map(|o| LinComb::from(o)).collect())
}
// run symbolic execution
false => match solver {
// we currently only support the following case:
// ```
// Ite(condition, consequence, alternative) -> consequence + (1 - condition) * (alternative - consequence)
// ```
Solver::Ite => {
let condition = inputs[0].clone();
let mut constant_inputs = constant_inputs;
let alternative = constant_inputs.pop().unwrap();
let consequence = constant_inputs.pop().unwrap();

match (condition.try_linear(), consequence, alternative) {
(Ok(condition), Ok(consequence), Ok(alternative)) => Ok(vec![
(LinComb::from(alternative.clone())
+ condition * &(consequence - alternative)),
]),
_ => Err((solver, inputs)),
}
}
solver => Err((solver, inputs)),
},
}
}

pub fn execute_solver<T: Field>(solver: &Solver, inputs: &[T]) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert_eq!(inputs.len(), expected_input_count);

let res = match solver {
Solver::Ite => match inputs[0].is_zero() {
true => vec![inputs[2].clone()],
false => vec![inputs[1].clone()],
},
Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()],
false => vec![
Expand Down