Skip to content

Commit

Permalink
chore: test boxing of fields in brillig vm
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench committed Jan 2, 2025
1 parent 7566b0f commit 5a8eea8
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 24 deletions.
2 changes: 1 addition & 1 deletion acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ fn extract_failure_payload_from_memory<F: AcirField>(
let error_selector = ErrorSelector::new(
revert_values_iter
.next()
.copied()
.cloned()
.expect("Incorrect revert data size")
.try_into()
.expect("Error selector is not u64"),
Expand Down
4 changes: 2 additions & 2 deletions acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub(crate) fn evaluate_binary_field_op<F: AcirField>(
rhs: MemoryValue<F>,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let a = match lhs {
MemoryValue::Field(a) => a,
MemoryValue::Field(a) => *a,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: bit_size.into(),
Expand All @@ -30,7 +30,7 @@ pub(crate) fn evaluate_binary_field_op<F: AcirField>(
}
};
let b = match rhs {
MemoryValue::Field(b) => b,
MemoryValue::Field(b) => *b,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: bit_size.into(),
Expand Down
14 changes: 7 additions & 7 deletions acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn read_heap_array<'a, F: AcirField>(
/// Extracts the last byte of every value
fn to_u8_vec<F: AcirField>(inputs: &[MemoryValue<F>]) -> Vec<u8> {
let mut result = Vec::with_capacity(inputs.len());
for &input in inputs {
for input in inputs {
result.push(input.try_into().unwrap());
}
result
Expand Down Expand Up @@ -81,7 +81,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
BlackBoxOp::Keccakf1600 { input, output } => {
let state_vec: Vec<u64> = read_heap_array(memory, input)
.iter()
.map(|&memory_value| memory_value.try_into().unwrap())
.map(|memory_value| memory_value.try_into().unwrap())
.collect();
let state: [u64; 25] = state_vec.try_into().unwrap();

Expand Down Expand Up @@ -145,7 +145,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
let points: Vec<F> = read_heap_vector(memory, points)
.iter()
.enumerate()
.map(|(i, &x)| {
.map(|(i, x)| {
if i % 3 == 2 {
let is_infinite: bool = x.try_into().unwrap();
F::from(is_infinite as u128)
Expand Down Expand Up @@ -245,9 +245,9 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
}
BlackBoxOp::BigIntFromLeBytes { inputs, modulus, output } => {
let input = read_heap_vector(memory, inputs);
let input: Vec<u8> = input.iter().map(|&x| x.try_into().unwrap()).collect();
let input: Vec<u8> = input.iter().map(|x| x.try_into().unwrap()).collect();
let modulus = read_heap_vector(memory, modulus);
let modulus: Vec<u8> = modulus.iter().map(|&x| x.try_into().unwrap()).collect();
let modulus: Vec<u8> = modulus.iter().map(|x| x.try_into().unwrap()).collect();

let new_id = bigint_solver.bigint_from_bytes(&input, &modulus)?;
memory.write(*output, new_id.into());
Expand Down Expand Up @@ -289,7 +289,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
format!("Expected 16 inputs but encountered {}", &inputs.len()),
));
}
for (i, &input) in inputs.iter().enumerate() {
for (i, input) in inputs.iter().enumerate() {
message[i] = input.try_into().unwrap();
}
let mut state = [0; 8];
Expand All @@ -300,7 +300,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
format!("Expected 8 values but encountered {}", &values.len()),
));
}
for (i, &value) in values.iter().enumerate() {
for (i, value) in values.iter().enumerate() {
state[i] = value.try_into().unwrap();
}

Expand Down
8 changes: 4 additions & 4 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {

/// Casts a value to a different bit size.
fn cast(&self, target_bit_size: BitSize, source_value: MemoryValue<F>) -> MemoryValue<F> {
match (source_value, target_bit_size) {
match (&source_value, target_bit_size) {
// Field to field, no op
(MemoryValue::Field(_), BitSize::Field) => source_value,
// Field downcast to u128
Expand All @@ -817,13 +817,13 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
}
// Integer upcast to field
(MemoryValue::Integer(integer, _), BitSize::Field) => {
MemoryValue::new_field(integer.into())
MemoryValue::new_field((*integer).into())
}
// Integer upcast to integer
(MemoryValue::Integer(integer, source_bit_size), BitSize::Integer(target_bit_size))
if source_bit_size <= target_bit_size =>
if *source_bit_size <= target_bit_size =>
{
MemoryValue::Integer(integer, target_bit_size)
MemoryValue::Integer(*integer, target_bit_size)
}
// Integer downcast
(MemoryValue::Integer(integer, _), BitSize::Integer(target_bit_size)) => {
Expand Down
56 changes: 50 additions & 6 deletions acvm-repo/brillig_vm/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use num_traits::{One, Zero};

pub const MEMORY_ADDRESSING_BIT_SIZE: IntegerBitSize = IntegerBitSize::U32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum MemoryValue<F> {
Field(F),
Field(Box<F>),
Integer(u128, IntegerBitSize),
}

Expand All @@ -21,7 +21,7 @@ pub enum MemoryTypeError {
impl<F> MemoryValue<F> {
/// Builds a field-typed memory value.
pub fn new_field(value: F) -> Self {
MemoryValue::Field(value)
MemoryValue::Field(Box::new(value))
}

/// Builds an integer-typed memory value.
Expand Down Expand Up @@ -86,7 +86,7 @@ impl<F: AcirField> MemoryValue<F> {
/// Converts the memory value to a field element, independent of its type.
pub fn to_field(&self) -> F {
match self {
MemoryValue::Field(value) => *value,
MemoryValue::Field(value) => **value,
MemoryValue::Integer(value, _) => F::from(*value),
}
}
Expand Down Expand Up @@ -181,6 +181,14 @@ impl<F: AcirField> TryFrom<MemoryValue<F>> for bool {
type Error = MemoryTypeError;

fn try_from(memory_value: MemoryValue<F>) -> Result<Self, Self::Error> {
bool::try_from(&memory_value)
}
}

impl<F: AcirField> TryFrom<&MemoryValue<F>> for bool {
type Error = MemoryTypeError;

fn try_from(memory_value: &MemoryValue<F>) -> Result<Self, Self::Error> {
let as_integer = memory_value.expect_integer_with_bit_size(IntegerBitSize::U1)?;

if as_integer.is_zero() {
Expand All @@ -197,6 +205,15 @@ impl<F: AcirField> TryFrom<MemoryValue<F>> for u8 {
type Error = MemoryTypeError;

fn try_from(memory_value: MemoryValue<F>) -> Result<Self, Self::Error> {
u8::try_from(&memory_value)
}
}


impl<F: AcirField> TryFrom<&MemoryValue<F>> for u8 {
type Error = MemoryTypeError;

fn try_from(memory_value: &MemoryValue<F>) -> Result<Self, Self::Error> {
memory_value.expect_integer_with_bit_size(IntegerBitSize::U8).map(|value| value as u8)
}
}
Expand All @@ -205,6 +222,15 @@ impl<F: AcirField> TryFrom<MemoryValue<F>> for u32 {
type Error = MemoryTypeError;

fn try_from(memory_value: MemoryValue<F>) -> Result<Self, Self::Error> {
u32::try_from(&memory_value)
}
}


impl<F: AcirField> TryFrom<&MemoryValue<F>> for u32 {
type Error = MemoryTypeError;

fn try_from(memory_value: &MemoryValue<F>) -> Result<Self, Self::Error> {
memory_value.expect_integer_with_bit_size(IntegerBitSize::U32).map(|value| value as u32)
}
}
Expand All @@ -213,6 +239,15 @@ impl<F: AcirField> TryFrom<MemoryValue<F>> for u64 {
type Error = MemoryTypeError;

fn try_from(memory_value: MemoryValue<F>) -> Result<Self, Self::Error> {
u64::try_from(&memory_value)
}
}


impl<F: AcirField> TryFrom<&MemoryValue<F>> for u64 {
type Error = MemoryTypeError;

fn try_from(memory_value: &MemoryValue<F>) -> Result<Self, Self::Error> {
memory_value.expect_integer_with_bit_size(IntegerBitSize::U64).map(|value| value as u64)
}
}
Expand All @@ -221,6 +256,15 @@ impl<F: AcirField> TryFrom<MemoryValue<F>> for u128 {
type Error = MemoryTypeError;

fn try_from(memory_value: MemoryValue<F>) -> Result<Self, Self::Error> {
u128::try_from(&memory_value)
}
}


impl<F: AcirField> TryFrom<&MemoryValue<F>> for u128 {
type Error = MemoryTypeError;

fn try_from(memory_value: &MemoryValue<F>) -> Result<Self, Self::Error> {
memory_value.expect_integer_with_bit_size(IntegerBitSize::U128)
}
}
Expand All @@ -247,7 +291,7 @@ impl<F: AcirField> Memory<F> {
/// Gets the value at address
pub fn read(&self, address: MemoryAddress) -> MemoryValue<F> {
let resolved_addr = self.resolve(address);
self.inner.get(resolved_addr).copied().unwrap_or_default()
self.inner.get(resolved_addr).cloned().unwrap_or_default()
}

pub fn read_ref(&self, ptr: MemoryAddress) -> MemoryAddress {
Expand Down Expand Up @@ -283,7 +327,7 @@ impl<F: AcirField> Memory<F> {
pub fn write_slice(&mut self, address: MemoryAddress, values: &[MemoryValue<F>]) {
let resolved_address = self.resolve(address);
self.resize_to_fit(resolved_address + values.len());
self.inner[resolved_address..(resolved_address + values.len())].copy_from_slice(values);
self.inner[resolved_address..(resolved_address + values.len())].clone_from_slice(values);
}

/// Returns the values of the memory
Expand Down
6 changes: 3 additions & 3 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -659,12 +659,12 @@ impl<'brillig> Context<'brillig> {
) -> ValueId {
match typ {
Type::Numeric(typ) => {
let memory = memory_values[*memory_index];
let memory = &memory_values[*memory_index];
*memory_index += 1;

let field_value = match memory {
MemoryValue::Field(field_value) => field_value,
MemoryValue::Integer(u128_value, _) => u128_value.into(),
MemoryValue::Field(field_value) => **field_value,
MemoryValue::Integer(u128_value, _) => (*u128_value).into(),
};
dfg.make_constant(field_value, typ)
}
Expand Down
2 changes: 1 addition & 1 deletion tooling/debugger/src/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> ReplDebugger<'a, B> {
for (index, value) in memory.iter().enumerate() {
// Zero field is the default value, we omit it when printing memory
if let MemoryValue::Field(field) = value {
if field == &FieldElement::zero() {
if **field == FieldElement::zero() {
continue;
}
}
Expand Down

0 comments on commit 5a8eea8

Please sign in to comment.