diff --git a/air-script/tests/codegen/masm.rs b/air-script/tests/codegen/masm.rs index d952e494..897e6a5c 100644 --- a/air-script/tests/codegen/masm.rs +++ b/air-script/tests/codegen/masm.rs @@ -74,6 +74,16 @@ fn constants() { expected.assert_eq(&generated_masm); } +#[test] +fn constant_in_range() { + let generated_masm = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) + .transpile(Target::Masm) + .unwrap(); + + let expected = expect_file!["../constant_in_range/constant_in_range.masm"]; + expected.assert_eq(&generated_masm); +} + #[test] fn evaluators() { let generated_masm = Test::new("tests/evaluators/evaluators.air".to_string()) diff --git a/air-script/tests/codegen/winterfell.rs b/air-script/tests/codegen/winterfell.rs index 3d2874e0..7227fd43 100644 --- a/air-script/tests/codegen/winterfell.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -74,6 +74,16 @@ fn constants() { expected.assert_eq(&generated_air); } +#[test] +fn constant_in_range() { + let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; + expected.assert_eq(&generated_air); +} + #[test] fn evaluators() { let generated_air = Test::new("tests/evaluators/evaluators.air".to_string()) diff --git a/air-script/tests/constant_in_range/constant_in_range.air b/air-script/tests/constant_in_range/constant_in_range.air new file mode 100644 index 00000000..25974c49 --- /dev/null +++ b/air-script/tests/constant_in_range/constant_in_range.air @@ -0,0 +1,21 @@ +def ConstantInRangeAir + +use constant_in_range_module::MIN; +const MAX = 3; + +trace_columns { + main: [a, b[3], c[4], d[4]], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf c[2].first = 0; +} + +integrity_constraints { + let m = [w + x - y - z for (w, x, y, z) in (MIN..MAX, b, c[MIN..MAX], d[MIN..MAX])]; + enf a = m[0] + m[1] + m[2]; +} \ No newline at end of file diff --git a/air-script/tests/constant_in_range/constant_in_range.masm b/air-script/tests/constant_in_range/constant_in_range.masm new file mode 100644 index 00000000..13ab1510 --- /dev/null +++ b/air-script/tests/constant_in_range/constant_in_range.masm @@ -0,0 +1,134 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] + exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 1 main and 0 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. +# This procedure pushes 1 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for main + padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop push.0 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900008 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900005 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900009 movdn.3 movdn.3 drop drop ext2sub ext2add push.2 push.0 padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900010 movdn.3 movdn.3 drop drop ext2sub ext2add ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the main trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_main_first + # boundary constraint 0 for main + padw mem_loadw.4294900006 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul +end # END PROC compute_boundary_constraints_main_first + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_main_first + # => [(first1, first0), ...] + # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints + diff --git a/air-script/tests/constant_in_range/constant_in_range.rs b/air-script/tests/constant_in_range/constant_in_range.rs new file mode 100644 index 00000000..b9b97791 --- /dev/null +++ b/air-script/tests/constant_in_range/constant_in_range.rs @@ -0,0 +1,90 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.stack_inputs.as_slice()); + } +} + +pub struct ConstantInRangeAir { + context: AirContext, + stack_inputs: [Felt; 16], +} + +impl ConstantInRangeAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for ConstantInRangeAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(1)]; + let aux_degrees = vec![]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(6, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[0] - (E::ZERO + main_current[1] - main_current[4] - main_current[8] + E::ONE + main_current[2] - main_current[5] - main_current[9] + E::from(2_u64) + main_current[3] - main_current[6] - main_current[10]); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/constant_in_range/constant_in_range_module.air b/air-script/tests/constant_in_range/constant_in_range_module.air new file mode 100644 index 00000000..ec0823b9 --- /dev/null +++ b/air-script/tests/constant_in_range/constant_in_range_module.air @@ -0,0 +1,3 @@ +mod constant_in_range_module + +const MIN = 0; \ No newline at end of file diff --git a/docs/src/description/constraints.md b/docs/src/description/constraints.md index 76b78c07..51fa2285 100644 --- a/docs/src/description/constraints.md +++ b/docs/src/description/constraints.md @@ -149,8 +149,8 @@ The following is a simple example of a valid `integrity_constraints` source sect def IntegrityConstraintsExample trace_columns { - main: [a, b] - aux: [p] + main: [a, b], + aux: [p], } public_inputs { @@ -163,11 +163,11 @@ boundary_constraints { integrity_constraints { # these are main constraints. they both express the same constraint. - enf a' = a + 1 - enf b' - b - 1 = 0 + enf a' = a + 1; + enf b' - b - 1 = 0; # this is an auxiliary constraint, since it uses an auxiliary trace column. - enf p = p' * a + enf p = p' * a; } ``` @@ -187,8 +187,8 @@ The following is an example of a valid `integrity_constraints` source section th def IntegrityConstraintsExample trace_columns { - main: [a, b] - aux: [p0, p1] + main: [a, b], + aux: [p0, p1], } public_inputs { @@ -196,11 +196,11 @@ public_inputs { } periodic_columns { - k: [1, 1, 1, 0] + k: [1, 1, 1, 0], } random_values { - rand: [16] + rand: [16], } boundary_constraints { @@ -209,14 +209,14 @@ boundary_constraints { integrity_constraints { # this is a main constraint that uses a periodic column. - enf a' = k * a + enf a' = k * a; # this is an auxiliary constraint that uses a periodic column. - enf p0' = k * p0 + enf p0' = k * p0; # these are auxiliary constraints that use random values from the verifier. - enf b = a + $rand[0] - enf p1 = k * (a + $rand[0]) * (b + $rand[1]) + enf b = a + $rand[0]; + enf p1 = k * (a + $rand[0]) * (b + $rand[1]); } ``` diff --git a/docs/src/description/convenience.md b/docs/src/description/convenience.md index 1d6ddbd7..d19fdad6 100644 --- a/docs/src/description/convenience.md +++ b/docs/src/description/convenience.md @@ -7,22 +7,28 @@ To make writing constraints easier, AirScript provides a number of syntactic con List comprehension provides a simple way to create new vectors. It is similar to the list comprehension syntax in Python. The following examples show how to use list comprehension in AirScript. ``` -let x = [a * 2 for a in b] +let x = [a * 2 for a in b]; ``` This will create a new vector with the same length as `b` and the value of each element will be twice that of the corresponding element in `b`. ``` -let x = [a + b for (a, b) in (c, d)] +let x = [a + b for (a, b) in (c, d)]; ``` This will create a new vector with the same length as `c` and `d` and the value of each element will be the sum of the corresponding elements in `c` and `d`. This will throw an error if `c` and `d` vectors are of unequal lengths. ``` -let x = [2^i * a for (i, a) in (0..5, b)] +let x = [2^i * a for (i, a) in (0..5, b)]; ``` Ranges can also be used as iterables, which makes it easy to refer to an element and its index at the same time. This will create a new vector with length 5 and each element will be the corresponding element in `b` multiplied by 2 raised to the power of the element's index. This will throw an error if `b` is not of length 5. ``` -let x = [m + n + o for (m, n, o) in (a, 0..5, c[0..5])] +const MAX = 5; +let x = [2^i * a for (i, a) in (0..MAX, b)]; +``` +Ranges are defined with each bound being either an integer literal, or a [named constant](./declarations.md#constants-const) of type scalar. + +``` +let x = [m + n + o for (m, n, o) in (a, 0..5, c[0..5])]; ``` Slices can also be used as iterables. This will create a new vector with length 5 and each element will be the sum of the corresponding elements in `a`, the range 0 to 5, and the first 5 elements of `c`. This will throw an error if `a` is not of length 5 or if `c` is of length less than 5. @@ -32,13 +38,13 @@ List folding provides syntactic convenience for folding vectors into expressions ``` trace_columns { - main: [a[5], b, c] + main: [a[5], b, c], } integrity_constraints { - let x = sum(a) - let y = sum([a[0], a[1], a[2], a[3], a[4]]) - let z = sum([a * 2 for a in a]) + let x = sum(a); + let y = sum([a[0], a[1], a[2], a[3], a[4]]); + let z = sum([a * 2 for a in a]); } ``` @@ -46,13 +52,13 @@ In the above, `x` and `y` both represent the sum of all trace column values in t ``` trace_columns { - main: [a[5], b, c] + main: [a[5], b, c], } integrity_constraints { - let x = prod(a) - let y = prod([a[0], a[1], a[2], a[3], a[4]]) - let z = prod([a + 2 for a in a]) + let x = prod(a); + let y = prod([a[0], a[1], a[2], a[3], a[4]]); + let z = prod([a + 2 for a in a]); } ``` @@ -63,36 +69,36 @@ In the above, `x` and `y` both represent the product of all trace column values Constraint comprehension provides a way to enforce the same constraint on multiple values. Conceptually, it is very similar to the list comprehension described above. For example: ``` trace_columns { - main: [a[5], b, c] + main: [a[5], b, c], } integrity_constraints { - enf v^2 = v for v in a + enf v^2 = v for v in a; } ``` The above will enforce $a_i^2 = a_i$ constraint for all columns in the trace column group `a`. Semantically, this is equivalent to: ``` trace_columns { - main: [a[5], b, c] + main: [a[5], b, c], } integrity_constraints { - enf a[0]^2 = a[0] - enf a[1]^2 = a[1] - enf a[2]^2 = a[2] - enf a[3]^2 = a[3] - enf a[4]^2 = a[4] + enf a[0]^2 = a[0]; + enf a[1]^2 = a[1]; + enf a[2]^2 = a[2]; + enf a[3]^2 = a[3]; + enf a[4]^2 = a[4]; } ``` Similar to list comprehension, constraints in constraint comprehension can involve values from multiple lists. For example: ``` trace_columns { - main: [a[5], b[5]] + main: [a[5], b[5]], } integrity_constraints { - enf x' = i * y for (x, y, i) in (a, b, 0..5) + enf x' = i * y for (x, y, i) in (a, b, 0..5); } ``` The above will enforce that $a_i' = i \cdot b_i$ for $i \in [0, 5)$. If the length of either `a` or `b` is not 5, this will throw an error. @@ -102,12 +108,12 @@ The above will enforce that $a_i' = i \cdot b_i$ for $i \in [0, 5)$. If the leng Frequently, we may want to enforce constraints based on some selectors. For example, let's say our trace has 4 columns: `a`, `b`, `c`, and `s`, and we want to enforce that $c' = a + b$ when $s = 1$ and $c' = a \cdot c$ when $s = 0$. We can write these constraints directly like so: ``` trace_columns { - main: [a, b, c, s] + main: [a, b, c, s], } integrity_constraints { - enf s^2 = s - enf c' = s * (a + b) + (1 - s) * (a * b) + enf s^2 = s; + enf c' = s * (a + b) + (1 - s) * (a * b); } ``` Notice that we also need to enforce $s^2 = s$ to ensure that column $s$ can contain only binary values. @@ -115,15 +121,15 @@ Notice that we also need to enforce $s^2 = s$ to ensure that column $s$ can cont While the above approach works, it gets more and more difficult to manage as selectors and constraints get more complicated. To simplify describing constraints for this use case, AirScript introduces `enf match` statement. The above constraints can be described using `enf match` statement as follows: ``` trace_columns { - main: [a, b, c, s] + main: [a, b, c, s], } integrity_constraints { - enf s^2 = s + enf s^2 = s; enf match { case s: c' = a + b case !s: c' = a * c - } + }; } ``` In the above, the syntax of each "option" is `case : `, where selector expression consists of values composed using binary operands and logical operators `!`, `&`, and `|`. AirScript reduces logical operations to their equivalent algebraic operations as follows: @@ -135,18 +141,18 @@ In the above, the syntax of each "option" is `case : AirBuilder<'a> { }, ast::Expr::Range(ref values) => { let values = values - .item - .clone() + .to_slice_range() .map(|v| self.insert_constant(v as u64)) .collect(); Ok(MemoizedBinding::Vector(values)) @@ -404,7 +403,7 @@ impl<'a> AirBuilder<'a> { AccessType::Default => MemoizedBinding::Vector(nodes.clone()), AccessType::Index(idx) => MemoizedBinding::Scalar(nodes[*idx]), AccessType::Slice(range) => { - MemoizedBinding::Vector(nodes[range.start..range.end].to_vec()) + MemoizedBinding::Vector(nodes[range.to_slice_range()].to_vec()) } AccessType::Matrix(_, _) => unreachable!(), }; @@ -415,7 +414,7 @@ impl<'a> AirBuilder<'a> { AccessType::Default => MemoizedBinding::Matrix(nodes.clone()), AccessType::Index(idx) => MemoizedBinding::Vector(nodes[*idx].clone()), AccessType::Slice(range) => { - MemoizedBinding::Matrix(nodes[range.start..range.end].to_vec()) + MemoizedBinding::Matrix(nodes[range.to_slice_range()].to_vec()) } AccessType::Matrix(row, col) => { MemoizedBinding::Scalar(nodes[*row][*col]) diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index cbb9aaac..3fffe911 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -436,15 +436,22 @@ impl RandBinding { /// Derive a new [RandBinding] derived from the current one given an [AccessType] pub fn access(&self, access_type: AccessType) -> Result { + use super::{RangeBound, RangeExpr}; match access_type { AccessType::Default => Ok(*self), AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), - AccessType::Slice(range) if range.end > self.size => { - Err(InvalidAccessError::IndexOutOfBounds) - } - AccessType::Slice(range) => { - let offset = self.offset + range.start; - let size = range.end - range.start; + AccessType::Slice(RangeExpr { + start: RangeBound::Const(start), + end: RangeBound::Const(end), + .. + }) if start > end => Err(InvalidAccessError::IndexOutOfBounds), + AccessType::Slice(RangeExpr { + start: RangeBound::Const(start), + end: RangeBound::Const(end), + .. + }) => { + let offset = self.offset + start.item; + let size = end.item - start.item; Ok(Self { offset, size, @@ -452,6 +459,9 @@ impl RandBinding { ..*self }) } + AccessType::Slice(_) => { + unreachable!("expected non-constant range bounds to have been erased by this point") + } AccessType::Index(_) if self.is_scalar() => Err(InvalidAccessError::IndexIntoScalar), AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(idx) => { diff --git a/parser/src/ast/errors.rs b/parser/src/ast/errors.rs index 0d8aab66..e8d8904c 100644 --- a/parser/src/ast/errors.rs +++ b/parser/src/ast/errors.rs @@ -7,6 +7,8 @@ pub enum InvalidExprError { InvalidExponent(SourceSpan), #[error("expected exponent to be a constant")] NonConstantExponent(SourceSpan), + #[error("expected constant range expression")] + NonConstantRangeExpr(SourceSpan), #[error("accessing column boundaries is not allowed here")] BoundedSymbolAccess(SourceSpan), #[error("expected scalar expression")] @@ -35,6 +37,14 @@ impl ToDiagnostic for InvalidExprError { "Only constant powers are supported with the exponentiation operator currently" .to_string(), ]), + Self::NonConstantRangeExpr(span) => Diagnostic::error() + .with_message("invalid expression") + .with_labels(vec![ + Label::primary(span.source_id(), span).with_message(message) + ]) + .with_notes(vec![ + "Range expression must be a constant to do this operation".to_string(), + ]), Self::InvalidExponent(span) | Self::BoundedSymbolAccess(span) | Self::InvalidScalarExpr(span) diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 9f16c013..be68042f 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -275,7 +275,7 @@ pub enum Expr { /// A constant expression Const(Span), /// An expression which evaluates to a vector of integers in the given range - Range(Span), + Range(RangeExpr), /// A vector of expressions /// /// A vector may be used to represent matrices in some situations, but such matrices @@ -307,14 +307,18 @@ impl Expr { /// /// NOTE: This only returns true for the `Const` and `Range` variants pub fn is_constant(&self) -> bool { - matches!(self, Self::Const(_) | Self::Range(_)) + match self { + Self::Const(_) => true, + Self::Range(range) => range.is_constant(), + _ => false, + } } /// Returns the resolved type of this expression, if known pub fn ty(&self) -> Option { match self { Self::Const(constant) => Some(constant.ty()), - Self::Range(range) => Some(Type::Vector(range.item.end - range.item.start)), + Self::Range(range) => range.ty(), Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) { Some(Type::Felt) => Some(Type::Vector(vector.len())), Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)), @@ -338,7 +342,7 @@ impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Const(ref expr) => f.debug_tuple("Const").field(&expr.item).finish(), - Self::Range(ref expr) => f.debug_tuple("Range").field(&expr.item).finish(), + Self::Range(ref expr) => f.debug_tuple("Range").field(&expr).finish(), Self::Vector(ref expr) => f.debug_tuple("Vector").field(&expr.item).finish(), Self::Matrix(ref expr) => f.debug_tuple("Matrix").field(&expr.item).finish(), Self::SymbolAccess(ref expr) => f.debug_tuple("SymbolAccess").field(expr).finish(), @@ -355,7 +359,7 @@ impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Const(ref expr) => write!(f, "{}", &expr), - Self::Range(ref range) => write!(f, "{}..{}", range.start, range.end), + Self::Range(ref range) => write!(f, "{}", range), Self::Vector(ref expr) => write!(f, "{}", DisplayList(expr.as_slice())), Self::Matrix(ref expr) => { f.write_str("[")?; @@ -595,6 +599,126 @@ impl fmt::Display for ScalarExpr { } } +/// Represents a symbol access to a named constant. +#[derive(Clone, Spanned, Debug)] +pub struct ConstSymbolAccess { + #[span] + pub span: SourceSpan, + pub name: ResolvableIdentifier, + pub ty: Option, +} +impl ConstSymbolAccess { + pub fn new(span: SourceSpan, name: Identifier) -> Self { + Self { + span, + name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)), + ty: None, + } + } +} +impl Eq for ConstSymbolAccess {} +impl PartialEq for ConstSymbolAccess { + fn eq(&self, other: &Self) -> bool { + self.name.eq(&other.name) && self.ty.eq(&other.ty) + } +} +impl fmt::Display for ConstSymbolAccess { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", &self.name) + } +} + +#[derive(Debug, Clone, Spanned)] +pub struct RangeExpr { + #[span] + pub span: SourceSpan, + pub start: RangeBound, + pub end: RangeBound, +} + +impl TryFrom<&RangeExpr> for Range { + type Error = InvalidExprError; + + #[inline] + fn try_from(expr: &RangeExpr) -> Result { + match (&expr.start, &expr.end) { + (RangeBound::Const(lhs), RangeBound::Const(rhs)) => Ok(lhs.item..rhs.item), + _ => Err(InvalidExprError::NonConstantRangeExpr(expr.span)), + } + } +} + +impl RangeExpr { + pub fn is_constant(&self) -> bool { + self.start.is_constant() && self.end.is_constant() + } + + /// Converts this range expression to a `Range` type, assuming it is constant. + /// Panics if the range is not constant. + pub fn to_slice_range(&self) -> Range { + self.try_into() + .expect("attempted to convert non-constant range expression to constant") + } + + pub fn ty(&self) -> Option { + match (&self.start, &self.end) { + (RangeBound::Const(start), RangeBound::Const(end)) => { + Some(Type::Vector(end.item.abs_diff(start.item))) + } + _ => None, + } + } +} +impl From for RangeExpr { + fn from(range: Range) -> Self { + Self { + span: SourceSpan::default(), + start: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.start)), + end: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.end)), + } + } +} +impl Eq for RangeExpr {} +impl PartialEq for RangeExpr { + fn eq(&self, other: &Self) -> bool { + self.start.eq(&other.start) && self.end.eq(&other.end) + } +} +impl fmt::Display for RangeExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}..{}", &self.start, &self.end) + } +} + +#[derive(Clone, Spanned, PartialEq, Eq, Debug)] +pub enum RangeBound { + SymbolAccess(ConstSymbolAccess), + Const(Span), +} +impl RangeBound { + pub fn is_constant(&self) -> bool { + matches!(self, Self::Const(_)) + } +} +impl From for RangeBound { + fn from(name: Identifier) -> Self { + Self::SymbolAccess(ConstSymbolAccess::new(name.span(), name)) + } +} +impl From for RangeBound { + fn from(constant: usize) -> Self { + Self::Const(Span::new(SourceSpan::UNKNOWN, constant)) + } +} +impl fmt::Display for RangeBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::SymbolAccess(sym) => write!(f, "{sym}"), + Self::Const(constant) => write!(f, "{constant}"), + } + } +} + /// Represents an expression requiring evaluation of a binary operator #[derive(Clone, Spanned)] pub struct BinaryExpr { @@ -689,7 +813,7 @@ pub enum AccessType { /// Access refers to the entire bound value Default, /// Access binds a sub-slice of a vector - Slice(Range), + Slice(RangeExpr), /// Access binds the value at a specific index of an aggregate value (i.e. vector or matrix) /// /// The result type may be either a scalar or a vector, depending on the type of the aggregate @@ -783,7 +907,9 @@ impl SymbolAccess { pub fn access(&self, access_type: AccessType) -> Result { match &self.access_type { AccessType::Default => self.access_default(access_type), - AccessType::Slice(base_range) => self.access_slice(base_range.clone(), access_type), + AccessType::Slice(base_range) => { + self.access_slice(base_range.to_slice_range(), access_type) + } AccessType::Index(base_idx) => self.access_index(*base_idx, access_type), AccessType::Matrix(_, _) => match access_type { AccessType::Default => Ok(self.clone()), @@ -812,10 +938,11 @@ impl SymbolAccess { }), }, AccessType::Slice(range) => { - let rlen = range.end - range.start; + let slice_range = range.to_slice_range(); + let rlen = slice_range.end - slice_range.start; match ty { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if range.end > len => { + Type::Vector(len) if slice_range.end > len => { Err(InvalidAccessError::IndexOutOfBounds) } Type::Vector(_) => Ok(Self { @@ -823,7 +950,7 @@ impl SymbolAccess { ty: Some(Type::Vector(rlen)), ..self.clone() }), - Type::Matrix(rows, _) if range.end > rows => { + Type::Matrix(rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) } Type::Matrix(_, cols) => Ok(Self { @@ -871,14 +998,19 @@ impl SymbolAccess { }), }, AccessType::Slice(range) => { + let slice_range = range.to_slice_range(); let blen = base_range.end - base_range.start; - let rlen = range.end - range.start; - let start = base_range.start + range.start; - let end = range.start + range.end; - let shifted = start..end; + let rlen = slice_range.len(); + let start = base_range.start + slice_range.start; + let end = slice_range.start + slice_range.end; + let shifted = RangeExpr { + span: range.span, + start: RangeBound::Const(Span::new(range.start.span(), start)), + end: RangeBound::Const(Span::new(range.end.span(), end)), + }; match ty { Type::Felt => unreachable!(), - Type::Vector(_) if range.end > blen => { + Type::Vector(_) if slice_range.end > blen => { Err(InvalidAccessError::IndexOutOfBounds) } Type::Vector(_) => Ok(Self { @@ -886,7 +1018,7 @@ impl SymbolAccess { ty: Some(Type::Vector(rlen)), ..self.clone() }), - Type::Matrix(rows, _) if range.end > rows => { + Type::Matrix(rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) } Type::Matrix(_, cols) => Ok(Self { diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index 2e92673e..6413909d 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -255,18 +255,20 @@ impl TraceBinding { match access_type { AccessType::Default => Ok(*self), AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), - AccessType::Slice(range) if range.end > self.size => { - Err(InvalidAccessError::IndexOutOfBounds) - } AccessType::Slice(range) => { - let offset = self.offset + range.start; - let size = range.end - range.start; - Ok(Self { - offset, - size, - ty: Type::Vector(size), - ..*self - }) + let slice_range = range.to_slice_range(); + if slice_range.end > self.size { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + let offset = self.offset + slice_range.start; + let size = slice_range.len(); + Ok(Self { + offset, + size, + ty: Type::Vector(size), + ..*self + }) + } } AccessType::Index(_) if self.is_scalar() => Err(InvalidAccessError::IndexIntoScalar), AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index 96620db8..995ab6d5 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -44,20 +44,28 @@ impl Type { ty if access_type == AccessType::Default => Ok(ty), Self::Felt => Err(InvalidAccessError::IndexIntoScalar), Self::Vector(len) => match access_type { - AccessType::Slice(range) if range.end > len => { - Err(InvalidAccessError::IndexOutOfBounds) + AccessType::Slice(range) => { + let slice_range = range.to_slice_range(); + if slice_range.end > len { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + Ok(Self::Vector(slice_range.len())) + } } - AccessType::Slice(range) => Ok(Self::Vector(range.end - range.start)), AccessType::Index(idx) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(_) => Ok(Self::Felt), AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar), _ => unreachable!(), }, Self::Matrix(rows, cols) => match access_type { - AccessType::Slice(range) if range.end > rows => { - Err(InvalidAccessError::IndexOutOfBounds) + AccessType::Slice(range) => { + let slice_range = range.to_slice_range(); + if slice_range.end > rows { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + Ok(Self::Matrix(slice_range.len(), cols)) + } } - AccessType::Slice(range) => Ok(Self::Matrix(range.end - range.start, cols)), AccessType::Index(idx) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(_) => Ok(Self::Vector(cols)), AccessType::Matrix(row, col) if row >= rows || col >= cols => { diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index 2bb33f0e..15432017 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -208,6 +208,18 @@ pub trait VisitMut { fn visit_mut_call(&mut self, expr: &mut ast::Call) -> ControlFlow { visit_mut_call(self, expr) } + fn visit_mut_range_bound(&mut self, expr: &mut ast::RangeBound) -> ControlFlow { + visit_mut_range_bound(self, expr) + } + fn visit_mut_access_type(&mut self, expr: &mut ast::AccessType) -> ControlFlow { + visit_mut_access_type(self, expr) + } + fn visit_mut_const_symbol_access( + &mut self, + expr: &mut ast::ConstSymbolAccess, + ) -> ControlFlow { + visit_mut_const_symbol_access(self, expr) + } fn visit_mut_bounded_symbol_access( &mut self, expr: &mut ast::BoundedSymbolAccess, @@ -338,6 +350,18 @@ where fn visit_mut_call(&mut self, expr: &mut ast::Call) -> ControlFlow { (**self).visit_mut_call(expr) } + fn visit_mut_range_bound(&mut self, expr: &mut ast::RangeBound) -> ControlFlow { + (**self).visit_mut_range_bound(expr) + } + fn visit_mut_access_type(&mut self, expr: &mut ast::AccessType) -> ControlFlow { + (**self).visit_mut_access_type(expr) + } + fn visit_mut_const_symbol_access( + &mut self, + expr: &mut ast::ConstSymbolAccess, + ) -> ControlFlow { + (**self).visit_mut_const_symbol_access(expr) + } fn visit_mut_bounded_symbol_access( &mut self, expr: &mut ast::BoundedSymbolAccess, @@ -582,7 +606,12 @@ where V: ?Sized + VisitMut, { match expr { - ast::Expr::Const(_) | ast::Expr::Range(_) => ControlFlow::Continue(()), + ast::Expr::Const(_) => ControlFlow::Continue(()), + ast::Expr::Range(ref mut range) => { + visitor.visit_mut_range_bound(&mut range.start)?; + visitor.visit_mut_range_bound(&mut range.end)?; + ControlFlow::Continue(()) + } ast::Expr::Vector(ref mut exprs) => { for expr in exprs.iter_mut() { visitor.visit_mut_expr(expr)?; @@ -659,6 +688,43 @@ where ControlFlow::Continue(()) } +pub fn visit_mut_range_bound(visitor: &mut V, expr: &mut ast::RangeBound) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + match expr { + ast::RangeBound::Const(_) => ControlFlow::Continue(()), + ast::RangeBound::SymbolAccess(ref mut access) => { + visitor.visit_mut_const_symbol_access(access) + } + } +} + +pub fn visit_mut_access_type(visitor: &mut V, expr: &mut ast::AccessType) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + match expr { + ast::AccessType::Default | ast::AccessType::Index(_) | ast::AccessType::Matrix(_, _) => { + ControlFlow::Continue(()) + } + ast::AccessType::Slice(ref mut range) => { + visitor.visit_mut_range_bound(&mut range.start)?; + visitor.visit_mut_range_bound(&mut range.end) + } + } +} + +pub fn visit_mut_const_symbol_access( + visitor: &mut V, + expr: &mut ast::ConstSymbolAccess, +) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + visitor.visit_mut_resolvable_identifier(&mut expr.name) +} + pub fn visit_mut_bounded_symbol_access( visitor: &mut V, expr: &mut ast::BoundedSymbolAccess, diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index c7669bc0..6a235119 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -351,7 +351,7 @@ MatchArm: Statement = { let generated_name = format!("%{}", *next_var); *next_var += 1; let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); - let context = vec![(generated_binding, Expr::Range(Span::new(SourceSpan::UNKNOWN, 0..1)))]; + let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; Statement::EnforceAll(ListComprehension::new(span!(l, r), constraint, context, Some(selector))) } } @@ -396,7 +396,7 @@ ConstraintExpr: Statement = { let generated_name = format!("%{}", *next_var); *next_var += 1; let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); - let context = vec![(generated_binding, Expr::Range(Span::new(SourceSpan::UNKNOWN, 0..1)))]; + let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; Statement::EnforceAll(ListComprehension::new(span!(l, r), expr, context, selector)) } else { Statement::Enforce(expr) @@ -555,7 +555,7 @@ Iterables: Vec = { Iterable: Expr = { => Expr::SymbolAccess(SymbolAccess::new(ident.span(), ident, AccessType::Default, 0)), - => Expr::Range(Span::new(span!(l, r), range)), + => Expr::Range(range), "[" "]" => Expr::SymbolAccess(SymbolAccess::new(span!(l, r), ident, AccessType::Slice(range), 0)), => if let ScalarExpr::Call(call) = function_call { Expr::Call(call) @@ -564,8 +564,13 @@ Iterable: Expr = { } } -Range: Range = { - ".." => Range { start: start as usize, end: end as usize } +Range: RangeExpr = { + ".." => RangeExpr { span: span!(l, r), start, end } +} + +RangeBound: RangeBound = { + => RangeBound::Const(Span::new(span!(l, r), value as usize)), + => RangeBound::SymbolAccess(ConstSymbolAccess::new(span!(l, r), name)), } // ATOMS diff --git a/parser/src/parser/tests/constant_propagation.rs b/parser/src/parser/tests/constant_propagation.rs index 7e705b0c..016c19fa 100644 --- a/parser/src/parser/tests/constant_propagation.rs +++ b/parser/src/parser/tests/constant_propagation.rs @@ -24,6 +24,7 @@ fn test_constant_propagation() { const A = [2, 4, 6, 8]; const B = [[1, 1], [2, 2]]; + const TWO = 2; integrity_constraints { enf test_constraint(b); diff --git a/parser/src/parser/tests/list_comprehension.rs b/parser/src/parser/tests/list_comprehension.rs index c11474c1..f4bca5b8 100644 --- a/parser/src/parser/tests/list_comprehension.rs +++ b/parser/src/parser/tests/list_comprehension.rs @@ -59,6 +59,8 @@ fn bc_identifier_and_range_lc() { let source = " def test + const THREE = 3; + trace_columns { main: [a, b, c[4]], } @@ -72,11 +74,14 @@ fn bc_identifier_and_range_lc() { } boundary_constraints { - let x = [2^i * c for (i, c) in (0..3, c)]; + let x = [2^i * c for (i, c) in (0..THREE, c)]; enf a.first = x[0] + x[1] + x[2] + x[3]; }"; let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + expected + .constants + .insert(ident!(THREE), constant!(THREE = 3)); expected .trace_columns .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); @@ -91,7 +96,7 @@ fn bc_identifier_and_range_lc() { expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ - let_!(x = lc!(((i, range!(0..3)), (c, expr!(access!(c)))) => mul!(exp!(int!(2), access!(i)), access!(c))).into() => + let_!(x = lc!(((i, range!(0usize, ident!(THREE))), (c, expr!(access!(c)))) => mul!(exp!(int!(2), access!(i)), access!(c))).into() => enforce!(eq!(bounded_access!(a, Boundary::First), add!(add!(add!(access!(x[0]), access!(x[1])), access!(x[2])), access!(x[3]))))), ], )); diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index d570c011..1011e208 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -292,7 +292,7 @@ macro_rules! slice { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(ident!($name))), - access_type: AccessType::Slice($range), + access_type: AccessType::Slice($range.into()), offset: 0, ty: None, }) @@ -302,7 +302,7 @@ macro_rules! slice { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), - access_type: AccessType::Slice($range), + access_type: AccessType::Slice($range.into()), offset: 0, ty: Some($ty), }) @@ -522,7 +522,14 @@ macro_rules! lc { macro_rules! range { ($range:expr) => { - Expr::Range(miden_diagnostics::Span::new(SourceSpan::UNKNOWN, $range)) + Expr::Range($range.into()) + }; + ($start:expr, $end:expr) => { + Expr::Range(RangeExpr { + span: miden_diagnostics::SourceSpan::UNKNOWN, + start: $start.into(), + end: $end.into(), + }) }; } diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index a0b2cc55..db389378 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -196,11 +196,13 @@ impl BindingType { Err(InvalidAccessError::IndexOutOfBounds) } AccessType::Index(idx) => Ok(elems[idx].clone()), - AccessType::Slice(range) if range.end > elems.len() => { - Err(InvalidAccessError::IndexOutOfBounds) - } AccessType::Slice(range) => { - Ok(Self::Vector(elems[range.start..range.end].to_vec())) + let slice_range = range.to_slice_range(); + if slice_range.end > elems.len() { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + Ok(Self::Vector(elems[slice_range].to_vec())) + } } AccessType::Matrix(row, _) if row >= elems.len() => { Err(InvalidAccessError::IndexOutOfBounds) diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 37861305..bae4533c 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt, mem, ops::ControlFlow, }; @@ -63,6 +63,7 @@ pub struct SemanticAnalysis<'a> { deps: &'a mut DependencyGraph, imported: Imported, globals: HashMap, + constants: BTreeMap, locals: LexicalScope, referenced: HashMap, current_module: Option, @@ -88,6 +89,7 @@ impl<'a> SemanticAnalysis<'a> { deps, imported, globals: Default::default(), + constants: Default::default(), locals: Default::default(), referenced: Default::default(), current_module: None, @@ -138,6 +140,14 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { fn visit_mut_module(&mut self, module: &mut Module) -> ControlFlow { self.current_module = Some(module.name); + // Collect the values of all named constants that can be referenced in range declarations + self.constants.extend( + module + .constants + .iter() + .map(|(id, c)| (*id, c.value.clone())), + ); + // Register all globals implicitly defined in the module before all locally bound names // // Currently this consists only of the `random_values` declarations. @@ -733,6 +743,98 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { } } + fn visit_mut_range_bound( + &mut self, + expr: &mut crate::ast::RangeBound, + ) -> ControlFlow { + match expr { + crate::ast::RangeBound::Const(_) => ControlFlow::Continue(()), + crate::ast::RangeBound::SymbolAccess(ref mut access) => { + self.visit_mut_const_symbol_access(access)?; + // The identifier must have been resolved to reach here + let qid = access.name.resolved().unwrap(); + let value = if self.current_module.as_ref() == Some(&qid.module) { + &self.constants[&qid.item.id()] + } else { + &self.library.modules[&qid.module].constants[&qid.item.id()].value + }; + match value { + ConstantExpr::Scalar(value) => { + let value = usize::try_from(*value).map_err(|err| { + self.diagnostics + .diagnostic(Severity::Error) + .with_primary_label( + expr.span(), + format!("constant is not a valid range bound: {err}"), + ) + .emit(); + SemanticAnalysisError::Invalid + }); + match value { + Ok(value) => { + *expr = + crate::ast::RangeBound::Const(Span::new(expr.span(), value)); + ControlFlow::Continue(()) + } + Err(err) => ControlFlow::Break(err), + } + } + const_expr => { + self.diagnostics + .diagnostic(Severity::Error) + .with_primary_label( + expr.span(), + format!( + "constant is not a valid range bound: expected scalar, got {}", + const_expr.ty() + ), + ) + .emit(); + ControlFlow::Break(SemanticAnalysisError::Invalid) + } + } + } + } + } + + fn visit_mut_const_symbol_access( + &mut self, + expr: &mut crate::ast::ConstSymbolAccess, + ) -> ControlFlow { + self.visit_mut_resolvable_identifier(&mut expr.name)?; + + // The identifier must have been resolved to reach here + let binding_ty = match self.resolvable_binding_type(&expr.name) { + Ok(ty) => ty, + Err(err) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_primary_label(expr.span, format!("invalid constant identifier: {err}")) + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + }; + match binding_ty.item { + BindingType::Constant(ty) => { + expr.ty = Some(ty); + ControlFlow::Continue(()) + } + binding_ty => { + self.diagnostics + .diagnostic(Severity::Error) + .with_primary_label( + expr.span, + format!( + "invalid constant identifier '{}', got binding of type {binding_ty}", + &expr.name + ), + ) + .emit(); + ControlFlow::Break(SemanticAnalysisError::Invalid) + } + } + } + fn visit_mut_bounded_symbol_access( &mut self, expr: &mut BoundedSymbolAccess, @@ -756,6 +858,7 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { expr: &mut SymbolAccess, ) -> ControlFlow { self.visit_mut_resolvable_identifier(&mut expr.name)?; + self.visit_mut_access_type(&mut expr.access_type)?; let resolved_binding_ty = match self.resolvable_binding_type(&expr.name) { Ok(ty) => ty, @@ -823,7 +926,8 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { // with a single column is dependent on the access type. A slice of columns of size 1 must // be captured as a vector of size 1 AccessType::Slice(ref range) => { - assert_eq!(expr.ty.replace(Type::Vector(range.end - range.start)), None) + let range = range.to_slice_range(); + assert_eq!(expr.ty.replace(Type::Vector(range.len())), None) } // All other access types can be derived from the binding type _ => assert_eq!(expr.ty.replace(binding_ty.ty().unwrap()), None), @@ -840,7 +944,10 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { .emit(); // Continue with a fabricated type let ty = match &expr.access_type { - AccessType::Slice(ref range) => Type::Vector(range.end - range.start), + AccessType::Slice(ref range) => { + let range = range.to_slice_range(); + Type::Vector(range.len()) + } _ => Type::Felt, }; assert_eq!(expr.ty.replace(ty), None); @@ -1533,7 +1640,9 @@ impl<'a> SemanticAnalysis<'a> { fn expr_binding_type(&self, expr: &Expr) -> Result { match expr { Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), - Expr::Range(range) => Ok(BindingType::Local(Type::Vector(range.end - range.start))), + Expr::Range(range) => Ok(BindingType::Local(Type::Vector( + range.to_slice_range().len(), + ))), Expr::Vector(ref elems) => { let mut binding_tys = Vec::with_capacity(elems.len()); for elem in elems.iter() { diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index cc08998e..28aedd25 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -127,11 +127,11 @@ impl<'a> ConstantPropagation<'a> { self.local.insert(expr.name, value.clone()); } Expr::Range(ref range) => { - let vector = range.item.clone().map(|i| i as u64).collect(); - self.local.insert( - expr.name, - Span::new(range.span(), ConstantExpr::Vector(vector)), - ); + let span = range.span(); + let range = range.to_slice_range(); + let vector = range.map(|i| i as u64).collect(); + self.local + .insert(expr.name, Span::new(span, ConstantExpr::Vector(vector))); } _ => unreachable!(), } @@ -316,7 +316,8 @@ impl<'a> VisitMut for ConstantPropagation<'a> { *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(value))); } AccessType::Slice(range) => { - let vector = value[range.start..range.end].to_vec(); + let range = range.to_slice_range(); + let vector = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector))); } AccessType::Index(idx) => { @@ -333,7 +334,8 @@ impl<'a> VisitMut for ConstantPropagation<'a> { *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(value))); } AccessType::Slice(range) => { - let matrix = value[range.start..range.end].to_vec(); + let range = range.to_slice_range(); + let matrix = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix))); } AccessType::Index(idx) => { @@ -506,7 +508,7 @@ impl<'a> VisitMut for ConstantPropagation<'a> { .. }) => rows.len(), Expr::Const(_) => panic!("expected iterable constant, got scalar"), - Expr::Range(range) => range.end - range.start, + Expr::Range(range) => range.to_slice_range().len(), _ => unreachable!(), }; @@ -532,6 +534,7 @@ impl<'a> VisitMut for ConstantPropagation<'a> { self.local.insert(binding, Span::new(span, value)); } Expr::Range(range) => { + let range = range.to_slice_range(); assert!(range.end > range.start + step); let value = ConstantExpr::Scalar((range.start + step) as u64); self.local.insert(binding, Span::new(span, value)); diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index f4221517..846dd854 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -950,6 +950,7 @@ impl<'a> Inlining<'a> { // Ranges are constant, so same rules as above apply here Expr::Range(range) => { let span = range.span(); + let range = range.to_slice_range(); let binding_ty = BindingType::Constant(Type::Felt); self.bindings.insert(binding, binding_ty); Expr::Const(Span::new( @@ -1538,7 +1539,7 @@ impl<'a> Inlining<'a> { } else { let start = tb.offset - original_binding.offset; ( - AccessType::Slice(start..(start + tb.size)), + AccessType::Slice(RangeExpr::from(start..(start + tb.size))), Type::Vector(tb.size), ) }; @@ -1580,7 +1581,9 @@ fn eval_expr_binding_type( ) -> Result { match expr { Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), - Expr::Range(range) => Ok(BindingType::Local(Type::Vector(range.end - range.start))), + Expr::Range(range) => Ok(BindingType::Local(Type::Vector( + range.to_slice_range().len(), + ))), Expr::Vector(ref elems) => match elems[0].ty() { None | Some(Type::Felt) => { let mut binding_tys = Vec::with_capacity(elems.len()); @@ -1692,7 +1695,7 @@ impl<'a> RewriteIterableBindingsVisitor<'a> { } Some(Expr::Range(range)) => { let span = range.span(); - let range = range.item.clone(); + let range = range.to_slice_range(); match access.access_type { AccessType::Index(idx) => Some(ScalarExpr::Const(Span::new( span,