Skip to content

Commit

Permalink
feat: add support for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Sep 19, 2023
1 parent 2ea09c6 commit 683d4c1
Show file tree
Hide file tree
Showing 25 changed files with 2,268 additions and 38 deletions.
26 changes: 26 additions & 0 deletions air-script/tests/codegen/masm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ fn evaluators() {
expected.assert_eq(&generated_masm);
}

#[test]
#[ignore]
fn functions() {
// let generated_masm = Test::new("tests/functions/functions_simple.air".to_string())
// .transpile(Target::Masm)
// .unwrap();

// let expected = expect_file!["../functions/functions_simple.masm"];
// expected.assert_eq(&generated_masm);

// make sure that the constraints generated using inlined functions are the same as the ones
// generated using regular functions
let generated_masm = Test::new("tests/functions/inlined_functions_simple.air".to_string())
.transpile(Target::Masm)
.unwrap();
let expected = expect_file!["../functions/functions_simple.masm"];
expected.assert_eq(&generated_masm);

let generated_masm = Test::new("tests/functions/functions_complex.air".to_string())
.transpile(Target::Masm)
.unwrap();

let expected = expect_file!["../functions/functions_complex.masm"];
expected.assert_eq(&generated_masm);
}

#[test]
fn variables() {
let generated_masm = Test::new("tests/variables/variables.air".to_string())
Expand Down
26 changes: 26 additions & 0 deletions air-script/tests/codegen/winterfell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ fn evaluators() {
expected.assert_eq(&generated_air);
}

#[test]
fn functions() {
let generated_air = Test::new("tests/functions/functions_simple.air".to_string())
.transpile(Target::Winterfell)
.unwrap();

let expected = expect_file!["../functions/functions_simple.rs"];
expected.assert_eq(&generated_air);

// make sure that the constraints generated using inlined functions are the same as the ones
// generated using regular functions
let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string())
.transpile(Target::Winterfell)
.unwrap();

let expected = expect_file!["../functions/functions_simple.rs"];
expected.assert_eq(&generated_air);

// let generated_air = Test::new("tests/functions/functions_complex.air".to_string())
// .transpile(Target::Winterfell)
// .unwrap();

// let expected = expect_file!["../functions/functions_complex.rs"];
// expected.assert_eq(&generated_air);
}

#[test]
fn variables() {
let generated_air = Test::new("tests/variables/variables.air".to_string())
Expand Down
37 changes: 37 additions & 0 deletions air-script/tests/functions/functions_complex.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
def FunctionsAir

fn get_multiplicity_flags(s0: felt, s1: felt) -> felt[4]:
return [!s0 & !s1, s0 & !s1, !s0 & s1, s0 & s1]

fn fold_vec(a: felt[12]) -> felt:
return sum([x for x in a])

fn fold_scalar_and_vec(a: felt, b: felt[12]) -> felt:
let m = fold_vec(b)
let n = m + 1
let o = n * 2
return o

trace_columns:
main: [t, s0, s1, v, b[12]]
aux: [b_range]

public_inputs:
stack_inputs: [16]

random_values:
alpha: [16]

boundary_constraints:
enf v.first = 0

integrity_constraints:
# let val = $alpha[0] + v
let f = get_multiplicity_flags(s0, s1)
let z = v^4 * f[3] + v^2 * f[2] + v * f[1] + f[0]
# let folded_value = fold_scalar_and_vec(v, b)
# enf b_range' = b_range * (z * t - t + 1)
enf b_range' = b_range * 2
# let y = fold_scalar_and_vec(v, b)
# let c = fold_scalar_and_vec(t, b)
# enf v' = y
166 changes: 166 additions & 0 deletions air-script/tests/functions/functions_complex.masm
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use.
#
# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors
#
# Input: [...]
# Output: [...]
proc.cache_z_exp
padw mem_loadw.4294903304 drop drop # load z
# => [z_1, z_0, ...]
# Exponentiate z trace_len times
mem_load.4294903307 neg
# => [count, z_1, z_0, ...] where count = -log2(trace_len)
dup.0 neq.0
while.true
movdn.2 dup.1 dup.1 ext2mul
# => [(e_1, e_0)^n, i, ...]
movup.2 add.1 dup.0 neq.0
# => [b, i+1, (e_1, e_0)^n, ...]
end # END while
push.0 mem_storew.500000100 # z^trace_len
# => [0, 0, (z_1, z_0)^trace_len, ...]
dropw # Clean stack
end # END PROC cache_z_exp

# Procedure to compute the exemption points.
#
# Input: [...]
# Output: [g^{-2}, g^{-1}, ...]
proc.get_exemptions_points
mem_load.4294799999
# => [g, ...]
push.1 swap div
# => [g^{-1}, ...]
dup.0 dup.0 mul
# => [g^{-2}, g^{-1}, ...]
end # END PROC get_exemptions_points

# Procedure to compute the integrity constraint divisor.
#
# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))`
# Procedure `cache_z_exp` must have been called prior to this.
#
# Input: [...]
# Output: [divisor_1, divisor_0, ...]
proc.compute_integrity_constraint_divisor
padw mem_loadw.500000100 drop drop # load z^trace_len
# Comments below use zt = `z^trace_len`
# => [zt_1, zt_0, ...]
push.1 push.0 ext2sub
# => [zt_1-1, zt_0-1, ...]
padw mem_loadw.4294903304 drop drop # load z
# => [z_1, z_0, zt_1-1, zt_0-1, ...]
exec.get_exemptions_points
# => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]
dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor
dup.3 dup.3 movup.3 push.0 ext2sub
# => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]
movup.4 movup.4 movup.4 push.0 ext2sub
# => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...]
ext2mul
# => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...]
ext2div
# => [divisor_1, divisor_0, ...]
end # END PROC compute_integrity_constraint_divisor

# Procedure to evaluate numerators of all integrity constraints.
#
# All the 1 main and 1 auxiliary constraints are evaluated.
# The result of each evaluation is kept on the stack, with the top of the stack
# containing the evaluations for the auxiliary trace (if any) followed by the main trace.
#
# Input: [...]
# Output: [(r_1, r_0)*, ...]
# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation.
# This procedure pushes 2 quadratic extension field elements to the stack
proc.compute_integrity_constraints
# integrity constraint 0 for main
padw mem_loadw.4294900003 drop drop padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900007 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900008 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900009 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900010 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900011 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900012 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900013 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900014 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900015 movdn.3 movdn.3 drop drop ext2add push.1 push.0 ext2add push.2 push.0 ext2mul ext2sub
# Multiply by the composition coefficient
padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul
# integrity constraint 0 for aux
padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add
# push the accumulator to the stack
push.1 movdn.2 push.0 movdn.2
# => [b1, b0, r1, r0, ...]
# square 2 times
dup.1 dup.1 ext2mul dup.1 dup.1 ext2mul
# multiply
dup.1 dup.1 movdn.5 movdn.5
# => [b1, b0, r1, r0, b1, b0, ...] (4 cycles)
ext2mul movdn.3 movdn.3
# => [b1, b0, r1', r0', ...] (5 cycles)
# clean stack
drop drop
# => [r1, r0, ...] (2 cycles)
padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add
# push the accumulator to the stack
push.1 movdn.2 push.0 movdn.2
# => [b1, b0, r1, r0, ...]
# square 1 times
dup.1 dup.1 ext2mul
# multiply
dup.1 dup.1 movdn.5 movdn.5
# => [b1, b0, r1, r0, b1, b0, ...] (4 cycles)
ext2mul movdn.3 movdn.3
# => [b1, b0, r1', r0', ...] (5 cycles)
# clean stack
drop drop
# => [r1, r0, ...] (2 cycles)
push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 ext2add ext2mul ext2sub
# Multiply by the composition coefficient
padw mem_loadw.4294900200 drop drop ext2mul
end # END PROC compute_integrity_constraints

# Procedure to evaluate the boundary constraint numerator for the first row of the main trace
#
# Input: [...]
# Output: [(r_1, r_0)*, ...]
# Where: (r_1, r_0) is one quadratic extension field element for each constraint
proc.compute_boundary_constraints_main_first
# boundary constraint 0 for main
padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub
# Multiply by the composition coefficient
padw mem_loadw.4294900201 movdn.3 movdn.3 drop drop ext2mul
end # END PROC compute_boundary_constraints_main_first

# Procedure to evaluate all integrity constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
# Where: (r_1, r_0) is the final result with the divisor applied
proc.evaluate_integrity_constraints
exec.compute_integrity_constraints
# Numerator of the transition constraint polynomial
ext2add ext2add
# Divisor of the transition constraint polynomial
exec.compute_integrity_constraint_divisor
ext2div # divide the numerator by the divisor
end # END PROC evaluate_integrity_constraints

# Procedure to evaluate all boundary constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
# Where: (r_1, r_0) is the final result with the divisor applied
proc.evaluate_boundary_constraints
exec.compute_boundary_constraints_main_first
# => [(first1, first0), ...]
# Compute the denominator for domain FirstRow
padw mem_loadw.4294903304 drop drop # load z
push.1 push.0 ext2sub
# Compute numerator/denominator for first row
ext2div
end # END PROC evaluate_boundary_constraints

# Procedure to evaluate the integrity and boundary constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
export.evaluate_constraints
exec.cache_z_exp
exec.evaluate_integrity_constraints
exec.evaluate_boundary_constraints
ext2add
end # END PROC evaluate_constraints

91 changes: 91 additions & 0 deletions air-script/tests/functions/functions_complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo};
use winter_math::fields::f64::BaseElement as Felt;
use winter_math::{ExtensionOf, FieldElement};
use winter_utils::collections::Vec;
use winter_utils::{ByteWriter, Serializable};

pub struct PublicInputs {
stack_inputs: [Felt; 16],
}

impl PublicInputs {
pub fn new(stack_inputs: [Felt; 16]) -> Self {
Self { stack_inputs }
}
}

impl Serializable for PublicInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.stack_inputs.as_slice());
}
}

pub struct FunctionsAir {
context: AirContext<Felt>,
stack_inputs: [Felt; 16],
}

impl FunctionsAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for FunctionsAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

fn context(&self) -> &AirContext<Felt> {
&self.context
}

fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![TransitionConstraintDegree::new(1)];
let aux_degrees = vec![TransitionConstraintDegree::new(8)];
let num_main_assertions = 1;
let num_aux_assertions = 0;

let context = AirContext::new_multi_segment(
trace_info,
main_degrees,
aux_degrees,
num_main_assertions,
num_aux_assertions,
options,
)
.set_num_transition_exemptions(2);
Self { context, stack_inputs: public_inputs.stack_inputs }
}

fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> {
vec![]
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let mut result = Vec::new();
result.push(Assertion::single(3, 0, Felt::ZERO));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let mut result = Vec::new();
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let main_current = frame.current();
let main_next = frame.next();
result[0] = main_next[3] - (main_current[4] + main_current[5] + main_current[6] + main_current[7] + main_current[8] + main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] + main_current[14] + main_current[15] + E::ONE) * E::from(2_u64);
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let main_current = main_frame.current();
let main_next = main_frame.next();
let aux_current = aux_frame.current();
let aux_next = aux_frame.next();
result[0] = aux_next[0] - aux_current[0] * (((aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE);
}
}
Loading

0 comments on commit 683d4c1

Please sign in to comment.