Skip to content

Commit

Permalink
feat: add support for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Aug 1, 2023
1 parent 5f7411b commit 7c6ef62
Show file tree
Hide file tree
Showing 21 changed files with 921 additions and 16 deletions.
10 changes: 10 additions & 0 deletions air-script/tests/codegen/masm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ fn evaluators() {
expected.assert_eq(&generated_masm);
}

#[test]
fn functions() {
let generated_masm = Test::new("tests/functions/functions.air".to_string())
.transpile(Target::Masm)
.unwrap();

let expected = expect_file!["../functions/functions.masm"];
expected.assert_eq(&generated_masm);
}

#[test]
fn variables() {
let generated_masm = Test::new("tests/variables/variables.air".to_string())
Expand Down
10 changes: 10 additions & 0 deletions air-script/tests/codegen/winterfell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ fn evaluators() {
expected.assert_eq(&generated_air);
}

#[test]
fn functions() {
let generated_air = Test::new("tests/functions/functions.air".to_string())
.transpile(Target::Winterfell)
.unwrap();

let expected = expect_file!["../functions/functions.rs"];
expected.assert_eq(&generated_air);
}

#[test]
fn variables() {
let generated_air = Test::new("tests/variables/variables.air".to_string())
Expand Down
23 changes: 23 additions & 0 deletions air-script/tests/functions/functions.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
def FunctionsAir

fn get_multiplicity_flags(s0: felt, s1: felt) -> felt[4]:
return [!s0 & !s1, s0 & !s1, !s0 & s1, s0 & s1]

trace_columns:
main: [t, s0, s1, v]
aux: [b_range]

public_inputs:
stack_inputs: [16]

random_values:
alpha: [16]

boundary_constraints:
enf v.first = 0

integrity_constraints:
let val = $alpha[0] + v
let f = get_multiplicity_flags(s0, s1)
let z = val^4 * f[3] + val^2 * f[2] + val * f[1] + f[0]
enf b_range' = b_range * (z * t - t + 1)
161 changes: 161 additions & 0 deletions air-script/tests/functions/functions.masm
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use.
#
# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors
#
# Input: [...]
# Output: [...]
proc.cache_z_exp
padw mem_loadw.4294903304 drop drop # load z
# => [z_1, z_0, ...]
# Exponentiate z trace_len times
mem_load.4294903307 neg
# => [count, z_1, z_0, ...] where count = -log2(trace_len)
dup.0 neq.0
while.true
movdn.2 dup.1 dup.1 ext2mul
# => [(e_1, e_0)^n, i, ...]
movup.2 add.1 dup.0 neq.0
# => [b, i+1, (e_1, e_0)^n, ...]
end # END while
push.0 mem_storew.500000100 # z^trace_len
# => [0, 0, (z_1, z_0)^trace_len, ...]
dropw # Clean stack
end # END PROC cache_z_exp

# Procedure to compute the exemption points.
#
# Input: [...]
# Output: [g^{-2}, g^{-1}, ...]
proc.get_exemptions_points
mem_load.4294799999
# => [g, ...]
push.1 swap div
# => [g^{-1}, ...]
dup.0 dup.0 mul
# => [g^{-2}, g^{-1}, ...]
end # END PROC get_exemptions_points

# Procedure to compute the integrity constraint divisor.
#
# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))`
# Procedure `cache_z_exp` must have been called prior to this.
#
# Input: [...]
# Output: [divisor_1, divisor_0, ...]
proc.compute_integrity_constraint_divisor
padw mem_loadw.500000100 drop drop # load z^trace_len
# Comments below use zt = `z^trace_len`
# => [zt_1, zt_0, ...]
push.1 push.0 ext2sub
# => [zt_1-1, zt_0-1, ...]
padw mem_loadw.4294903304 drop drop # load z
# => [z_1, z_0, zt_1-1, zt_0-1, ...]
exec.get_exemptions_points
# => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]
dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor
dup.3 dup.3 movup.3 push.0 ext2sub
# => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]
movup.4 movup.4 movup.4 push.0 ext2sub
# => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...]
ext2mul
# => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...]
ext2div
# => [divisor_1, divisor_0, ...]
end # END PROC compute_integrity_constraint_divisor

# Procedure to evaluate numerators of all integrity constraints.
#
# All the 0 main and 1 auxiliary constraints are evaluated.
# The result of each evaluation is kept on the stack, with the top of the stack
# containing the evaluations for the auxiliary trace (if any) followed by the main trace.
#
# Input: [...]
# Output: [(r_1, r_0)*, ...]
# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation.
# This procedure pushes 1 quadratic extension field elements to the stack
proc.compute_integrity_constraints
# integrity constraint 0 for aux
padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add
# push the accumulator to the stack
push.1 movdn.2 push.0 movdn.2
# => [b1, b0, r1, r0, ...]
# square 2 times
dup.1 dup.1 ext2mul dup.1 dup.1 ext2mul
# multiply
dup.1 dup.1 movdn.5 movdn.5
# => [b1, b0, r1, r0, b1, b0, ...] (4 cycles)
ext2mul movdn.3 movdn.3
# => [b1, b0, r1', r0', ...] (5 cycles)
# clean stack
drop drop
# => [r1, r0, ...] (2 cycles)
padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add
# push the accumulator to the stack
push.1 movdn.2 push.0 movdn.2
# => [b1, b0, r1, r0, ...]
# square 1 times
dup.1 dup.1 ext2mul
# multiply
dup.1 dup.1 movdn.5 movdn.5
# => [b1, b0, r1, r0, b1, b0, ...] (4 cycles)
ext2mul movdn.3 movdn.3
# => [b1, b0, r1', r0', ...] (5 cycles)
# clean stack
drop drop
# => [r1, r0, ...] (2 cycles)
push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 ext2add ext2mul ext2sub
# Multiply by the composition coefficient
padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul
end # END PROC compute_integrity_constraints

# Procedure to evaluate the boundary constraint numerator for the first row of the main trace
#
# Input: [...]
# Output: [(r_1, r_0)*, ...]
# Where: (r_1, r_0) is one quadratic extension field element for each constraint
proc.compute_boundary_constraints_main_first
# boundary constraint 0 for main
padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub
# Multiply by the composition coefficient
padw mem_loadw.4294900200 drop drop ext2mul
end # END PROC compute_boundary_constraints_main_first

# Procedure to evaluate all integrity constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
# Where: (r_1, r_0) is the final result with the divisor applied
proc.evaluate_integrity_constraints
exec.compute_integrity_constraints
# Numerator of the transition constraint polynomial
ext2add
# Divisor of the transition constraint polynomial
exec.compute_integrity_constraint_divisor
ext2div # divide the numerator by the divisor
end # END PROC evaluate_integrity_constraints

# Procedure to evaluate all boundary constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
# Where: (r_1, r_0) is the final result with the divisor applied
proc.evaluate_boundary_constraints
exec.compute_boundary_constraints_main_first
# => [(first1, first0), ...]
# Compute the denominator for domain FirstRow
padw mem_loadw.4294903304 drop drop # load z
push.1 push.0 ext2sub
# Compute numerator/denominator for first row
ext2div
end # END PROC evaluate_boundary_constraints

# Procedure to evaluate the integrity and boundary constraints.
#
# Input: [...]
# Output: [(r_1, r_0), ...]
export.evaluate_constraints
exec.cache_z_exp
exec.evaluate_integrity_constraints
exec.evaluate_boundary_constraints
ext2add
end # END PROC evaluate_constraints
90 changes: 90 additions & 0 deletions air-script/tests/functions/functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo};
use winter_math::fields::f64::BaseElement as Felt;
use winter_math::{ExtensionOf, FieldElement};
use winter_utils::collections::Vec;
use winter_utils::{ByteWriter, Serializable};

pub struct PublicInputs {
stack_inputs: [Felt; 16],
}

impl PublicInputs {
pub fn new(stack_inputs: [Felt; 16]) -> Self {
Self { stack_inputs }
}
}

impl Serializable for PublicInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.stack_inputs.as_slice());
}
}

pub struct FunctionsAir {
context: AirContext<Felt>,
stack_inputs: [Felt; 16],
}

impl FunctionsAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for FunctionsAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

fn context(&self) -> &AirContext<Felt> {
&self.context
}

fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![];
let aux_degrees = vec![TransitionConstraintDegree::new(8)];
let num_main_assertions = 1;
let num_aux_assertions = 0;

let context = AirContext::new_multi_segment(
trace_info,
main_degrees,
aux_degrees,
num_main_assertions,
num_aux_assertions,
options,
)
.set_num_transition_exemptions(2);
Self { context, stack_inputs: public_inputs.stack_inputs }
}

fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> {
vec![]
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let mut result = Vec::new();
result.push(Assertion::single(3, 0, Felt::ZERO));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let mut result = Vec::new();
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let main_current = frame.current();
let main_next = frame.next();
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let main_current = main_frame.current();
let main_next = main_frame.next();
let aux_current = aux_frame.current();
let aux_next = aux_frame.next();
result[0] = aux_next[0] - aux_current[0] * (((aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE);
}
}
55 changes: 55 additions & 0 deletions parser/src/ast/declarations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub enum Declaration {
///
/// Evaluator functions can be defined in any module of the program
EvaluatorFunction(EvaluatorFunction),
/// A pure function definition
///
/// Pure functions can be defined in any module of the program
Function(Function),
/// A `periodic_columns` section declaration
///
/// This may appear any number of times in the program, and may be declared in any module.
Expand Down Expand Up @@ -523,3 +527,54 @@ impl PartialEq for EvaluatorFunction {
self.name == other.name && self.params == other.params && self.body == other.body
}
}

/// Functions take a group of expressions as parameters and returns a group of expressions. These
/// values can be a Felt, a Vector or a Matrix. Functions do not take trace bindings as parameters.
#[derive(Debug, Clone, Spanned)]
pub struct Function {
#[span]
pub span: SourceSpan,
pub name: Identifier,
pub params: Vec<(Identifier, Type)>,
pub return_type: Type,
pub body: Vec<Statement>,
}
impl Function {
/// Creates a new function.
pub const fn new(
span: SourceSpan,
name: Identifier,
params: Vec<(Identifier, Type)>,
return_type: Type,
body: Vec<Statement>,
) -> Self {
Self {
span,
name,
params,
return_type,
body,
}
}

pub fn param_bindings(&self) -> Vec<Identifier> {
self.params
.iter()
.map(|(name, _)| *name)
.collect::<Vec<_>>()
}

pub fn param_types(&self) -> Vec<Type> {
self.params.iter().map(|(_, ty)| *ty).collect::<Vec<_>>()
}
}

impl Eq for Function {}
impl PartialEq for Function {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
&& self.params == other.params
&& self.return_type == other.return_type
&& self.body == other.body
}
}
Loading

0 comments on commit 7c6ef62

Please sign in to comment.