From 10a9f8104e1ebb8fad044927ff130a0e2ce9131b Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Tue, 26 Nov 2024 00:08:08 +0000 Subject: [PATCH] feat(comptime): Implement blackbox functions in comptime interpreter (#6551) Co-authored-by: jfecher Co-authored-by: Maxim Vezenov --- Cargo.toml | 2 + acvm-repo/acir/Cargo.toml | 4 +- .../acir/src/circuit/black_box_functions.rs | 4 +- acvm-repo/acvm/Cargo.toml | 11 +- acvm-repo/blackbox_solver/src/bigint.rs | 48 ++ acvm-repo/blackbox_solver/src/lib.rs | 2 +- acvm-repo/brillig_vm/src/black_box.rs | 57 +-- acvm-repo/brillig_vm/src/lib.rs | 4 +- compiler/noirc_frontend/Cargo.toml | 4 +- .../noirc_frontend/src/hir/comptime/errors.rs | 10 + .../src/hir/comptime/interpreter.rs | 455 +++++++----------- .../src/hir/comptime/interpreter/builtin.rs | 38 +- .../interpreter/builtin/builtin_helpers.rs | 160 +++++- .../src/hir/comptime/interpreter/foreign.rs | 421 ++++++++++++++-- .../noirc_frontend/src/hir/comptime/tests.rs | 26 +- .../comptime_blackbox/Nargo.toml | 7 + .../comptime_blackbox/src/main.nr | 155 ++++++ tooling/noirc_abi/Cargo.toml | 4 +- 18 files changed, 988 insertions(+), 424 deletions(-) create mode 100644 test_programs/noir_test_success/comptime_blackbox/Nargo.toml create mode 100644 test_programs/noir_test_success/comptime_blackbox/src/main.nr diff --git a/Cargo.toml b/Cargo.toml index d6a56b7990d..94ebe54fde1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -157,6 +157,8 @@ proptest-derive = "0.4.0" rayon = "1.8.0" sha2 = { version = "0.10.6", features = ["compress"] } sha3 = "0.10.6" +strum = "0.24" +strum_macros = "0.24" im = { version = "15.1", features = ["serde"] } tracing = "0.1.40" diff --git a/acvm-repo/acir/Cargo.toml b/acvm-repo/acir/Cargo.toml index 1f2ad1fb787..8139a58eefc 100644 --- a/acvm-repo/acir/Cargo.toml +++ b/acvm-repo/acir/Cargo.toml @@ -24,11 +24,11 @@ flate2.workspace = true bincode.workspace = true base64.workspace = true serde-big-array = "0.5.1" +strum = { workspace = true } +strum_macros = { workspace = true } [dev-dependencies] serde_json = "1.0" -strum = "0.24" -strum_macros = "0.24" serde-reflection = "0.3.6" serde-generate = "0.25.1" fxhash.workspace = true diff --git a/acvm-repo/acir/src/circuit/black_box_functions.rs b/acvm-repo/acir/src/circuit/black_box_functions.rs index 2e5a94f1c50..25842c14dbc 100644 --- a/acvm-repo/acir/src/circuit/black_box_functions.rs +++ b/acvm-repo/acir/src/circuit/black_box_functions.rs @@ -4,12 +4,10 @@ //! implemented in more basic constraints. use serde::{Deserialize, Serialize}; -#[cfg(test)] use strum_macros::EnumIter; #[allow(clippy::upper_case_acronyms)] -#[derive(Clone, Debug, Hash, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[cfg_attr(test, derive(EnumIter))] +#[derive(Clone, Debug, Hash, Copy, PartialEq, Eq, Serialize, Deserialize, EnumIter)] pub enum BlackBoxFunc { /// Ciphers (encrypts) the provided plaintext using AES128 in CBC mode, /// padding the input using PKCS#7. diff --git a/acvm-repo/acvm/Cargo.toml b/acvm-repo/acvm/Cargo.toml index 6ebfed27a32..e513ae4e727 100644 --- a/acvm-repo/acvm/Cargo.toml +++ b/acvm-repo/acvm/Cargo.toml @@ -25,11 +25,7 @@ acvm_blackbox_solver.workspace = true indexmap = "1.7.0" [features] -bn254 = [ - "acir/bn254", - "brillig_vm/bn254", - "acvm_blackbox_solver/bn254", -] +bn254 = ["acir/bn254", "brillig_vm/bn254", "acvm_blackbox_solver/bn254"] bls12_381 = [ "acir/bls12_381", "brillig_vm/bls12_381", @@ -37,10 +33,11 @@ bls12_381 = [ ] [dev-dependencies] -ark-bls12-381 = { version = "^0.4.0", default-features = false, features = ["curve"] } +ark-bls12-381 = { version = "^0.4.0", default-features = false, features = [ + "curve", +] } ark-bn254.workspace = true bn254_blackbox_solver.workspace = true proptest.workspace = true zkhash = { version = "^0.2.0", default-features = false } num-bigint.workspace = true - diff --git a/acvm-repo/blackbox_solver/src/bigint.rs b/acvm-repo/blackbox_solver/src/bigint.rs index b8bc9dc0d70..540862843ab 100644 --- a/acvm-repo/blackbox_solver/src/bigint.rs +++ b/acvm-repo/blackbox_solver/src/bigint.rs @@ -97,3 +97,51 @@ impl BigIntSolver { Ok(()) } } + +/// Wrapper over the generic bigint solver to automatically assign bigint IDs. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct BigIntSolverWithId { + solver: BigIntSolver, + last_id: u32, +} + +impl BigIntSolverWithId { + pub fn create_bigint_id(&mut self) -> u32 { + let output = self.last_id; + self.last_id += 1; + output + } + + pub fn bigint_from_bytes( + &mut self, + inputs: &[u8], + modulus: &[u8], + ) -> Result { + let id = self.create_bigint_id(); + self.solver.bigint_from_bytes(inputs, modulus, id)?; + Ok(id) + } + + pub fn bigint_to_bytes(&self, input: u32) -> Result, BlackBoxResolutionError> { + self.solver.bigint_to_bytes(input) + } + + pub fn bigint_op( + &mut self, + lhs: u32, + rhs: u32, + func: BlackBoxFunc, + ) -> Result { + let modulus_lhs = self.solver.get_modulus(lhs, func)?; + let modulus_rhs = self.solver.get_modulus(rhs, func)?; + if modulus_lhs != modulus_rhs { + return Err(BlackBoxResolutionError::Failed( + func, + "moduli should be identical in BigInt operation".to_string(), + )); + } + let id = self.create_bigint_id(); + self.solver.bigint_op(lhs, rhs, id, func)?; + Ok(id) + } +} diff --git a/acvm-repo/blackbox_solver/src/lib.rs b/acvm-repo/blackbox_solver/src/lib.rs index d8f926fcb4b..0fa56c2f531 100644 --- a/acvm-repo/blackbox_solver/src/lib.rs +++ b/acvm-repo/blackbox_solver/src/lib.rs @@ -18,7 +18,7 @@ mod hash; mod logic; pub use aes128::aes128_encrypt; -pub use bigint::BigIntSolver; +pub use bigint::{BigIntSolver, BigIntSolverWithId}; pub use curve_specific_solver::{BlackBoxFunctionSolver, StubbedBlackBoxSolver}; pub use ecdsa::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify}; pub use hash::{blake2s, blake3, keccakf1600, sha256_compression}; diff --git a/acvm-repo/brillig_vm/src/black_box.rs b/acvm-repo/brillig_vm/src/black_box.rs index 0d90a4c8502..19e2dd7553d 100644 --- a/acvm-repo/brillig_vm/src/black_box.rs +++ b/acvm-repo/brillig_vm/src/black_box.rs @@ -1,9 +1,8 @@ use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, IntegerBitSize}; use acir::{AcirField, BlackBoxFunc}; -use acvm_blackbox_solver::BigIntSolver; use acvm_blackbox_solver::{ aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccakf1600, - sha256_compression, BlackBoxFunctionSolver, BlackBoxResolutionError, + sha256_compression, BigIntSolverWithId, BlackBoxFunctionSolver, BlackBoxResolutionError, }; use num_bigint::BigUint; use num_traits::Zero; @@ -39,11 +38,13 @@ fn to_value_vec(input: &[u8]) -> Vec> { input.iter().map(|&x| x.into()).collect() } +pub(crate) type BrilligBigIntSolver = BigIntSolverWithId; + pub(crate) fn evaluate_black_box>( op: &BlackBoxOp, solver: &Solver, memory: &mut Memory, - bigint_solver: &mut BrilligBigintSolver, + bigint_solver: &mut BrilligBigIntSolver, ) -> Result<(), BlackBoxResolutionError> { match op { BlackBoxOp::AES128Encrypt { inputs, iv, key, outputs } => { @@ -56,7 +57,7 @@ pub(crate) fn evaluate_black_box })?; let key: [u8; 16] = to_u8_vec(read_heap_array(memory, key)).try_into().map_err(|_| { - BlackBoxResolutionError::Failed(bb_func, "Invalid ley length".to_string()) + BlackBoxResolutionError::Failed(bb_func, "Invalid key length".to_string()) })?; let ciphertext = aes128_encrypt(&inputs, iv, key)?; @@ -353,54 +354,6 @@ pub(crate) fn evaluate_black_box } } -/// Wrapper over the generic bigint solver to automatically assign bigint ids in brillig -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub(crate) struct BrilligBigintSolver { - bigint_solver: BigIntSolver, - last_id: u32, -} - -impl BrilligBigintSolver { - pub(crate) fn create_bigint_id(&mut self) -> u32 { - let output = self.last_id; - self.last_id += 1; - output - } - - pub(crate) fn bigint_from_bytes( - &mut self, - inputs: &[u8], - modulus: &[u8], - ) -> Result { - let id = self.create_bigint_id(); - self.bigint_solver.bigint_from_bytes(inputs, modulus, id)?; - Ok(id) - } - - pub(crate) fn bigint_to_bytes(&self, input: u32) -> Result, BlackBoxResolutionError> { - self.bigint_solver.bigint_to_bytes(input) - } - - pub(crate) fn bigint_op( - &mut self, - lhs: u32, - rhs: u32, - func: BlackBoxFunc, - ) -> Result { - let modulus_lhs = self.bigint_solver.get_modulus(lhs, func)?; - let modulus_rhs = self.bigint_solver.get_modulus(rhs, func)?; - if modulus_lhs != modulus_rhs { - return Err(BlackBoxResolutionError::Failed( - func, - "moduli should be identical in BigInt operation".to_string(), - )); - } - let id = self.create_bigint_id(); - self.bigint_solver.bigint_op(lhs, rhs, id, func)?; - Ok(id) - } -} - fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { match op { BlackBoxOp::AES128Encrypt { .. } => BlackBoxFunc::AES128Encrypt, diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 45025fbb208..5b3688339b5 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -17,7 +17,7 @@ use acir::brillig::{ use acir::AcirField; use acvm_blackbox_solver::BlackBoxFunctionSolver; use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op, BrilligArithmeticError}; -use black_box::{evaluate_black_box, BrilligBigintSolver}; +use black_box::{evaluate_black_box, BrilligBigIntSolver}; // Re-export `brillig`. pub use acir::brillig; @@ -95,7 +95,7 @@ pub struct VM<'a, F, B: BlackBoxFunctionSolver> { /// The solver for blackbox functions black_box_solver: &'a B, // The solver for big integers - bigint_solver: BrilligBigintSolver, + bigint_solver: BrilligBigIntSolver, // Flag that determines whether we want to profile VM. profiling_active: bool, // Samples for profiling the VM execution. diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index 581d7f1b61d..5d1520af54f 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -30,8 +30,8 @@ cfg-if.workspace = true tracing.workspace = true petgraph = "0.6" rangemap = "1.4.0" -strum = "0.24" -strum_macros = "0.24" +strum.workspace = true +strum_macros.workspace = true [dev-dependencies] diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index 198ba91156e..446c4dae2d3 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -200,6 +200,11 @@ pub enum InterpreterError { item: String, location: Location, }, + InvalidInComptimeContext { + item: String, + location: Location, + explanation: String, + }, TypeAnnotationsNeededForMethodCall { location: Location, }, @@ -291,6 +296,7 @@ impl InterpreterError { | InterpreterError::UnsupportedTopLevelItemUnquote { location, .. } | InterpreterError::ComptimeDependencyCycle { location, .. } | InterpreterError::Unimplemented { location, .. } + | InterpreterError::InvalidInComptimeContext { location, .. } | InterpreterError::NoImpl { location, .. } | InterpreterError::ImplMethodTypeMismatch { location, .. } | InterpreterError::DebugEvaluateComptime { location, .. } @@ -540,6 +546,10 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("{item} is currently unimplemented"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } + InterpreterError::InvalidInComptimeContext { item, location, explanation } => { + let msg = format!("{item} is invalid in comptime context"); + CustomDiagnostic::simple_error(msg, explanation.clone(), location.span) + } InterpreterError::BreakNotInLoop { location } => { let msg = "There is no loop to break out of!".into(); CustomDiagnostic::simple_error(msg, String::new(), location.span) diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 994318a371a..49fd86b73bb 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1,6 +1,7 @@ use std::collections::VecDeque; use std::{collections::hash_map::Entry, rc::Rc}; +use acvm::blackbox_solver::BigIntSolverWithId; use acvm::{acir::AcirField, FieldElement}; use fm::FileId; use im::Vector; @@ -62,6 +63,9 @@ pub struct Interpreter<'local, 'interner> { /// multiple times. Without this map, when one of these inner functions exits we would /// unbind the generic completely instead of resetting it to its previous binding. bound_generics: Vec>, + + /// Stateful bigint calculator. + bigint_solver: BigIntSolverWithId, } #[allow(unused)] @@ -71,9 +75,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { crate_id: CrateId, current_function: Option, ) -> Self { - let bound_generics = Vec::new(); - let in_loop = false; - Self { elaborator, crate_id, current_function, bound_generics, in_loop } + Self { + elaborator, + crate_id, + current_function, + bound_generics: Vec::new(), + in_loop: false, + bigint_solver: BigIntSolverWithId::default(), + } } pub(crate) fn call_function( @@ -227,11 +236,9 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { .expect("all builtin functions must contain a function attribute which contains the opcode which it links to"); if let Some(builtin) = func_attrs.builtin() { - let builtin = builtin.clone(); - self.call_builtin(&builtin, arguments, return_type, location) + self.call_builtin(builtin.clone().as_str(), arguments, return_type, location) } else if let Some(foreign) = func_attrs.foreign() { - let foreign = foreign.clone(); - foreign::call_foreign(self.elaborator.interner, &foreign, arguments, location) + self.call_foreign(foreign.clone().as_str(), arguments, return_type, location) } else if let Some(oracle) = func_attrs.oracle() { if oracle == "print" { self.print_oracle(arguments) @@ -906,6 +913,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } } + #[allow(clippy::bool_comparison)] fn evaluate_infix(&mut self, infix: HirInfixExpression, id: ExprId) -> IResult { let lhs_value = self.evaluate(infix.lhs)?; let rhs_value = self.evaluate(infix.rhs)?; @@ -924,310 +932,183 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { InterpreterError::InvalidValuesForBinary { lhs, rhs, location, operator } }; - use InterpreterError::InvalidValuesForBinary; - match infix.operator.kind { - BinaryOpKind::Add => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs + rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_add(rhs).ok_or(error("+"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_add(rhs).ok_or(error("+"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_add(rhs).ok_or(error("+"))?)) + /// Generate matches that can promote the type of one side to the other if they are compatible. + macro_rules! match_values { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) { + $( + ($lhs_var:ident, $rhs_var:ident) to $res_var:ident => $expr:expr + ),* + $(,)? + } + ) => { + match ($lhs_value, $rhs_value) { + $( + (Value::$lhs_var($lhs), Value::$rhs_var($rhs)) => { + Ok(Value::$res_var(($expr).ok_or(error($op))?)) + }, + )* + (lhs, rhs) => { + Err(error($op)) + }, + } + }; + } + + /// Generate matches for arithmetic operations on `Field` and integers. + macro_rules! match_arithmetic { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) { field: $field_expr:expr, int: $int_expr:expr, }) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Field, Field) to Field => Some($field_expr), + (I8, I8) to I8 => $int_expr, + (I16, I16) to I16 => $int_expr, + (I32, I32) to I32 => $int_expr, + (I64, I64) to I64 => $int_expr, + (U8, U8) to U8 => $int_expr, + (U16, U16) to U16 => $int_expr, + (U32, U32) to U32 => $int_expr, + (U64, U64) to U64 => $int_expr, + } } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for comparison operations on all types, returning `Bool`. + macro_rules! match_cmp { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Field, Field) to Bool => Some($expr), + (Bool, Bool) to Bool => Some($expr), + (I8, I8) to Bool => Some($expr), + (I16, I16) to Bool => Some($expr), + (I32, I32) to Bool => Some($expr), + (I64, I64) to Bool => Some($expr), + (U8, U8) to Bool => Some($expr), + (U16, U16) to Bool => Some($expr), + (U32, U32) to Bool => Some($expr), + (U64, U64) to Bool => Some($expr), + } } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for bitwise operations on `Bool` and integers. + macro_rules! match_bitwise { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (Bool, Bool) to Bool => Some($expr), + (I8, I8) to I8 => Some($expr), + (I16, I16) to I16 => Some($expr), + (I32, I32) to I32 => Some($expr), + (I64, I64) to I64 => Some($expr), + (U8, U8) to U8 => Some($expr), + (U16, U16) to U16 => Some($expr), + (U32, U32) to U32 => Some($expr), + (U64, U64) to U64 => Some($expr), + } } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for operations on just integer values. + macro_rules! match_integer { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (I8, I8) to I8 => $expr, + (I16, I16) to I16 => $expr, + (I32, I32) to I32 => $expr, + (I64, I64) to I64 => $expr, + (U8, U8) to U8 => $expr, + (U16, U16) to U16 => $expr, + (U32, U32) to U32 => $expr, + (U64, U64) to U64 => $expr, + } } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + /// Generate matches for bit shifting, which in Noir only accepts `u8` for RHS. + macro_rules! match_bitshift { + (($lhs_value:ident as $lhs:ident $op:literal $rhs_value:ident as $rhs:ident) => $expr:expr) => { + match_values! { + ($lhs_value as $lhs $op $rhs_value as $rhs) { + (I8, U8) to I8 => $expr, + (I16, U8) to I16 => $expr, + (I32, U8) to I32 => $expr, + (I64, U8) to I64 => $expr, + (U8, U8) to U8 => $expr, + (U16, U8) to U16 => $expr, + (U32, U8) to U32 => $expr, + (U64, U8) to U64 => $expr, + } } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_add(rhs).ok_or(error("+"))?)) + }; + } + + use InterpreterError::InvalidValuesForBinary; + match infix.operator.kind { + BinaryOpKind::Add => match_arithmetic! { + (lhs_value as lhs "+" rhs_value as rhs) { + field: lhs + rhs, + int: lhs.checked_add(rhs), } - (lhs, rhs) => Err(error("+")), }, - BinaryOpKind::Subtract => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs - rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_sub(rhs).ok_or(error("-"))?)) + BinaryOpKind::Subtract => match_arithmetic! { + (lhs_value as lhs "-" rhs_value as rhs) { + field: lhs - rhs, + int: lhs.checked_sub(rhs), } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_sub(rhs).ok_or(error("-"))?)) - } - (lhs, rhs) => Err(error("-")), }, - BinaryOpKind::Multiply => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs * rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_mul(rhs).ok_or(error("*"))?)) + BinaryOpKind::Multiply => match_arithmetic! { + (lhs_value as lhs "*" rhs_value as rhs) { + field: lhs * rhs, + int: lhs.checked_mul(rhs), } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_mul(rhs).ok_or(error("*"))?)) - } - (lhs, rhs) => Err(error("*")), }, - BinaryOpKind::Divide => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs / rhs)), - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_div(rhs).ok_or(error("/"))?)) + BinaryOpKind::Divide => match_arithmetic! { + (lhs_value as lhs "/" rhs_value as rhs) { + field: lhs / rhs, + int: lhs.checked_div(rhs), } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_div(rhs).ok_or(error("/"))?)) - } - (lhs, rhs) => Err(error("/")), }, - BinaryOpKind::Equal => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs == rhs)), - (lhs, rhs) => Err(error("==")), + BinaryOpKind::Equal => match_cmp! { + (lhs_value as lhs "==" rhs_value as rhs) => lhs == rhs }, - BinaryOpKind::NotEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs != rhs)), - (lhs, rhs) => Err(error("!=")), + BinaryOpKind::NotEqual => match_cmp! { + (lhs_value as lhs "!=" rhs_value as rhs) => lhs != rhs }, - BinaryOpKind::Less => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), - (lhs, rhs) => Err(error("<")), + BinaryOpKind::Less => match_cmp! { + (lhs_value as lhs "<" rhs_value as rhs) => lhs < rhs }, - BinaryOpKind::LessEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (lhs, rhs) => Err(error("<=")), + BinaryOpKind::LessEqual => match_cmp! { + (lhs_value as lhs "<=" rhs_value as rhs) => lhs <= rhs }, - BinaryOpKind::Greater => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), - (lhs, rhs) => Err(error(">")), + BinaryOpKind::Greater => match_cmp! { + (lhs_value as lhs ">" rhs_value as rhs) => lhs > rhs }, - BinaryOpKind::GreaterEqual => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (lhs, rhs) => Err(error(">=")), + BinaryOpKind::GreaterEqual => match_cmp! { + (lhs_value as lhs ">=" rhs_value as rhs) => lhs >= rhs }, - BinaryOpKind::And => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs & rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs & rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs & rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs & rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs & rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs & rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs & rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), - (lhs, rhs) => Err(error("&")), + BinaryOpKind::And => match_bitwise! { + (lhs_value as lhs "&" rhs_value as rhs) => lhs & rhs }, - BinaryOpKind::Or => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs | rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs | rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs | rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs | rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs | rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs | rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs | rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), - (lhs, rhs) => Err(error("|")), + BinaryOpKind::Or => match_bitwise! { + (lhs_value as lhs "|" rhs_value as rhs) => lhs | rhs }, - BinaryOpKind::Xor => match (lhs_value.clone(), rhs_value.clone()) { - (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs ^ rhs)), - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs ^ rhs)), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs ^ rhs)), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs ^ rhs)), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs ^ rhs)), - (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs ^ rhs)), - (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs ^ rhs)), - (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), - (lhs, rhs) => Err(error("^")), + BinaryOpKind::Xor => match_bitwise! { + (lhs_value as lhs "^" rhs_value as rhs) => lhs ^ rhs }, - BinaryOpKind::ShiftRight => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_shr(rhs.into()).ok_or(error(">>"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_shr(rhs.into()).ok_or(error(">>"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_shr(rhs).ok_or(error(">>"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64( - lhs.checked_shr(rhs.try_into().map_err(|_| error(">>"))?).ok_or(error(">>"))?, - )), - (lhs, rhs) => Err(error(">>")), + BinaryOpKind::ShiftRight => match_bitshift! { + (lhs_value as lhs ">>" rhs_value as rhs) => lhs.checked_shr(rhs.into()) }, - BinaryOpKind::ShiftLeft => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_shl(rhs.into()).ok_or(error("<<"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_shl(rhs.into()).ok_or(error("<<"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_shl(rhs).ok_or(error("<<"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64( - lhs.checked_shl(rhs.try_into().map_err(|_| error("<<"))?).ok_or(error("<<"))?, - )), - (lhs, rhs) => Err(error("<<")), + BinaryOpKind::ShiftLeft => match_bitshift! { + (lhs_value as lhs "<<" rhs_value as rhs) => lhs.checked_shl(rhs.into()) }, - BinaryOpKind::Modulo => match (lhs_value.clone(), rhs_value.clone()) { - (Value::I8(lhs), Value::I8(rhs)) => { - Ok(Value::I8(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I16(lhs), Value::I16(rhs)) => { - Ok(Value::I16(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I32(lhs), Value::I32(rhs)) => { - Ok(Value::I32(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::I64(lhs), Value::I64(rhs)) => { - Ok(Value::I64(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U8(lhs), Value::U8(rhs)) => { - Ok(Value::U8(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U16(lhs), Value::U16(rhs)) => { - Ok(Value::U16(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U32(lhs), Value::U32(rhs)) => { - Ok(Value::U32(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (Value::U64(lhs), Value::U64(rhs)) => { - Ok(Value::U64(lhs.checked_rem(rhs).ok_or(error("%"))?)) - } - (lhs, rhs) => Err(error("%")), + BinaryOpKind::Modulo => match_integer! { + (lhs_value as lhs "%" rhs_value as rhs) => lhs.checked_rem(rhs) }, } } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index e8b34ab4323..3d8ccf78926 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -1,14 +1,14 @@ use std::rc::Rc; -use acvm::{AcirField, FieldElement}; +use acvm::{acir::BlackBoxFunc, AcirField, FieldElement}; use builtin_helpers::{ - block_expression_to_value, check_argument_count, check_function_not_yet_resolved, - check_one_argument, check_three_arguments, check_two_arguments, get_bool, get_expr, get_field, - get_format_string, get_function_def, get_module, get_quoted, get_slice, get_struct, - get_trait_constraint, get_trait_def, get_trait_impl, get_tuple, get_type, get_typed_expr, - get_u32, get_unresolved_type, has_named_attribute, hir_pattern_to_tokens, - mutate_func_meta_type, parse, quote_ident, replace_func_meta_parameters, - replace_func_meta_return_type, + block_expression_to_value, byte_array_type, check_argument_count, + check_function_not_yet_resolved, check_one_argument, check_three_arguments, + check_two_arguments, get_bool, get_expr, get_field, get_format_string, get_function_def, + get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, + get_trait_impl, get_tuple, get_type, get_typed_expr, get_u32, get_unresolved_type, + has_named_attribute, hir_pattern_to_tokens, mutate_func_meta_type, parse, quote_ident, + replace_func_meta_parameters, replace_func_meta_return_type, }; use im::Vector; use iter_extended::{try_vecmap, vecmap}; @@ -42,7 +42,7 @@ use crate::{ }; use self::builtin_helpers::{eq_item, get_array, get_ctstring, get_str, get_u8, hash_item, lex}; -use super::{foreign, Interpreter}; +use super::Interpreter; pub(crate) mod builtin_helpers; @@ -57,7 +57,9 @@ impl<'local, 'context> Interpreter<'local, 'context> { let interner = &mut self.elaborator.interner; let call_stack = &self.elaborator.interpreter_call_stack; match name { - "apply_range_constraint" => foreign::apply_range_constraint(arguments, location), + "apply_range_constraint" => { + self.call_foreign("range", arguments, return_type, location) + } "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), "array_refcount" => Ok(Value::U32(0)), @@ -234,8 +236,11 @@ impl<'local, 'context> Interpreter<'local, 'context> { "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "unresolved_type_is_unit" => unresolved_type_is_unit(interner, arguments, location), "zeroed" => zeroed(return_type, location.span), + blackbox if BlackBoxFunc::is_valid_black_box_func_name(blackbox) => { + self.call_foreign(blackbox, arguments, return_type, location) + } _ => { - let item = format!("Comptime evaluation for builtin function {name}"); + let item = format!("Comptime evaluation for builtin function '{name}'"); Err(InterpreterError::Unimplemented { item, location }) } } @@ -324,10 +329,7 @@ fn str_as_bytes( let string = get_str(interner, string)?; let bytes: im::Vector = string.bytes().map(Value::U8).collect(); - let byte_array_type = Type::Array( - Box::new(Type::Constant(bytes.len().into(), Kind::u32())), - Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight)), - ); + let byte_array_type = byte_array_type(bytes.len()); Ok(Value::Array(bytes, byte_array_type)) } @@ -820,10 +822,8 @@ fn to_le_radix( Some(digit) => Value::U8(*digit), None => Value::U8(0), }); - Ok(Value::Array( - decomposed_integer.into(), - Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), - )) + let result_type = byte_array_type(decomposed_integer.len()); + Ok(Value::Array(decomposed_integer.into(), result_type)) } fn compute_to_radix_le(field: FieldElement, radix: u32) -> Vec { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index 3f9d92cfe88..cf90aab32e0 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use std::{hash::Hasher, rc::Rc}; use acvm::FieldElement; +use iter_extended::try_vecmap; use noirc_errors::Location; use crate::hir::comptime::display::tokens_to_string; @@ -30,6 +31,8 @@ use crate::{ token::{SecondaryAttribute, Token, Tokens}, QuotedType, Type, }; +use crate::{Kind, Shared, StructType}; +use rustc_hash::FxHashMap as HashMap; pub(crate) fn check_argument_count( expected: usize, @@ -45,38 +48,40 @@ pub(crate) fn check_argument_count( } pub(crate) fn check_one_argument( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<(Value, Location)> { - check_argument_count(1, &arguments, location)?; + let [arg1] = check_arguments(arguments, location)?; - Ok(arguments.pop().unwrap()) + Ok(arg1) } pub(crate) fn check_two_arguments( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<((Value, Location), (Value, Location))> { - check_argument_count(2, &arguments, location)?; - - let argument2 = arguments.pop().unwrap(); - let argument1 = arguments.pop().unwrap(); + let [arg1, arg2] = check_arguments(arguments, location)?; - Ok((argument1, argument2)) + Ok((arg1, arg2)) } #[allow(clippy::type_complexity)] pub(crate) fn check_three_arguments( - mut arguments: Vec<(Value, Location)>, + arguments: Vec<(Value, Location)>, location: Location, ) -> IResult<((Value, Location), (Value, Location), (Value, Location))> { - check_argument_count(3, &arguments, location)?; + let [arg1, arg2, arg3] = check_arguments(arguments, location)?; - let argument3 = arguments.pop().unwrap(); - let argument2 = arguments.pop().unwrap(); - let argument1 = arguments.pop().unwrap(); + Ok((arg1, arg2, arg3)) +} - Ok((argument1, argument2, argument3)) +#[allow(clippy::type_complexity)] +pub(crate) fn check_arguments( + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult<[(Value, Location); N]> { + check_argument_count(N, &arguments, location)?; + Ok(arguments.try_into().expect("checked arg count")) } pub(crate) fn get_array( @@ -93,6 +98,47 @@ pub(crate) fn get_array( } } +/// Get the fields if the value is a `Value::Struct`, otherwise report that a struct type +/// with `name` is expected. Returns the `Type` but doesn't verify that it's called `name`. +pub(crate) fn get_struct_fields( + name: &str, + (value, location): (Value, Location), +) -> IResult<(HashMap, Value>, Type)> { + match value { + Value::Struct(fields, typ) => Ok((fields, typ)), + _ => { + let expected = StructType::new( + StructId::dummy_id(), + Ident::new(name.to_string(), location.span), + location, + Vec::new(), + Vec::new(), + ); + let expected = Type::Struct(Shared::new(expected), Vec::new()); + type_mismatch(value, expected, location) + } + } +} + +/// Get a specific field of a struct and apply a decoder function on it. +pub(crate) fn get_struct_field( + field_name: &str, + struct_fields: &HashMap, Value>, + struct_type: &Type, + location: Location, + f: impl Fn((Value, Location)) -> IResult, +) -> IResult { + let key = Rc::new(field_name.to_string()); + let Some(value) = struct_fields.get(&key) else { + return Err(InterpreterError::ExpectedStructToHaveField { + typ: struct_type.clone(), + field_name: Rc::into_inner(key).unwrap(), + location, + }); + }; + f((value.clone(), location)) +} + pub(crate) fn get_bool((value, location): (Value, Location)) -> IResult { match value { Value::Bool(value) => Ok(value), @@ -114,6 +160,49 @@ pub(crate) fn get_slice( } } +/// Interpret the input as a slice, then map each element. +/// Returns the values in the slice and the original type. +pub(crate) fn get_slice_map( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult, +) -> IResult<(Vec, Type)> { + let (values, typ) = get_slice(interner, (value, location))?; + let values = try_vecmap(values, |value| f((value, location)))?; + Ok((values, typ)) +} + +/// Interpret the input as an array, then map each element. +/// Returns the values in the array and the original array type. +pub(crate) fn get_array_map( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult, +) -> IResult<(Vec, Type)> { + let (values, typ) = get_array(interner, (value, location))?; + let values = try_vecmap(values, |value| f((value, location)))?; + Ok((values, typ)) +} + +/// Get an array and convert it to a fixed size. +/// Returns the values in the array and the original array type. +pub(crate) fn get_fixed_array_map( + interner: &NodeInterner, + (value, location): (Value, Location), + f: impl Fn((Value, Location)) -> IResult, +) -> IResult<([T; N], Type)> { + let (values, typ) = get_array_map(interner, (value, location), f)?; + + values.try_into().map(|v| (v, typ.clone())).map_err(|_| { + // Assuming that `values.len()` corresponds to `typ`. + let Type::Array(_, ref elem) = typ else { + unreachable!("get_array_map checked it was an array") + }; + let expected = Type::Array(Box::new(Type::Constant(N.into(), Kind::u32())), elem.clone()); + InterpreterError::TypeMismatch { expected, actual: typ, location } + }) +} + pub(crate) fn get_str( interner: &NodeInterner, (value, location): (Value, Location), @@ -520,3 +609,44 @@ pub(super) fn eq_item( let other_arg = get_item(other_arg)?; Ok(Value::Bool(self_arg == other_arg)) } + +/// Type to be used in `Value::Array(, )`. +pub(crate) fn byte_array_type(len: usize) -> Type { + Type::Array( + Box::new(Type::Constant(len.into(), Kind::u32())), + Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight)), + ) +} + +/// Type to be used in `Value::Slice(, )`. +pub(crate) fn byte_slice_type() -> Type { + Type::Slice(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight))) +} + +/// Create a `Value::Array` from bytes. +pub(crate) fn to_byte_array(values: &[u8]) -> Value { + Value::Array(values.iter().copied().map(Value::U8).collect(), byte_array_type(values.len())) +} + +/// Create a `Value::Slice` from bytes. +pub(crate) fn to_byte_slice(values: &[u8]) -> Value { + Value::Slice(values.iter().copied().map(Value::U8).collect(), byte_slice_type()) +} + +/// Create a `Value::Array` from fields. +pub(crate) fn to_field_array(values: &[FieldElement]) -> Value { + let typ = Type::Array( + Box::new(Type::Constant(values.len().into(), Kind::u32())), + Box::new(Type::FieldElement), + ); + Value::Array(values.iter().copied().map(Value::Field).collect(), typ) +} + +/// Create a `Value::Struct` from fields and the expected return type. +pub(crate) fn to_struct( + fields: impl IntoIterator, + typ: Type, +) -> Value { + let fields = fields.into_iter().map(|(k, v)| (Rc::new(k.to_string()), v)).collect(); + Value::Struct(fields, typ) +} diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs index 3de72969cab..d2611f72535 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs @@ -1,40 +1,126 @@ use acvm::{ - acir::BlackBoxFunc, blackbox_solver::BlackBoxFunctionSolver, AcirField, BlackBoxResolutionError, + acir::BlackBoxFunc, + blackbox_solver::{BigIntSolverWithId, BlackBoxFunctionSolver}, + AcirField, BlackBoxResolutionError, FieldElement, }; -use bn254_blackbox_solver::Bn254BlackBoxSolver; +use bn254_blackbox_solver::Bn254BlackBoxSolver; // Currently locked to only bn254! use im::Vector; -use iter_extended::try_vecmap; use noirc_errors::Location; use crate::{ - hir::comptime::{errors::IResult, InterpreterError, Value}, + hir::comptime::{ + errors::IResult, interpreter::builtin::builtin_helpers::to_byte_array, InterpreterError, + Value, + }, node_interner::NodeInterner, + Type, }; -use super::builtin::builtin_helpers::{ - check_one_argument, check_two_arguments, get_array, get_field, get_u32, get_u64, +use super::{ + builtin::builtin_helpers::{ + check_arguments, check_one_argument, check_three_arguments, check_two_arguments, + get_array_map, get_bool, get_field, get_fixed_array_map, get_slice_map, get_struct_field, + get_struct_fields, get_u32, get_u64, get_u8, to_byte_slice, to_field_array, to_struct, + }, + Interpreter, }; -pub(super) fn call_foreign( +impl<'local, 'context> Interpreter<'local, 'context> { + pub(super) fn call_foreign( + &mut self, + name: &str, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, + ) -> IResult { + call_foreign( + self.elaborator.interner, + &mut self.bigint_solver, + name, + arguments, + return_type, + location, + ) + } +} + +// Similar to `evaluate_black_box` in `brillig_vm`. +fn call_foreign( interner: &mut NodeInterner, + bigint_solver: &mut BigIntSolverWithId, name: &str, - arguments: Vec<(Value, Location)>, + args: Vec<(Value, Location)>, + return_type: Type, location: Location, ) -> IResult { + use BlackBoxFunc::*; + match name { - "poseidon2_permutation" => poseidon2_permutation(interner, arguments, location), - "keccakf1600" => keccakf1600(interner, arguments, location), + "aes128_encrypt" => aes128_encrypt(interner, args, location), + "bigint_from_le_bytes" => { + bigint_from_le_bytes(interner, bigint_solver, args, return_type, location) + } + "bigint_to_le_bytes" => bigint_to_le_bytes(bigint_solver, args, location), + "bigint_add" => bigint_op(bigint_solver, BigIntAdd, args, return_type, location), + "bigint_sub" => bigint_op(bigint_solver, BigIntSub, args, return_type, location), + "bigint_mul" => bigint_op(bigint_solver, BigIntMul, args, return_type, location), + "bigint_div" => bigint_op(bigint_solver, BigIntDiv, args, return_type, location), + "blake2s" => blake_hash(interner, args, location, acvm::blackbox_solver::blake2s), + "blake3" => blake_hash(interner, args, location, acvm::blackbox_solver::blake3), + "ecdsa_secp256k1" => ecdsa_secp256_verify( + interner, + args, + location, + acvm::blackbox_solver::ecdsa_secp256k1_verify, + ), + "ecdsa_secp256r1" => ecdsa_secp256_verify( + interner, + args, + location, + acvm::blackbox_solver::ecdsa_secp256r1_verify, + ), + "embedded_curve_add" => embedded_curve_add(args, location), + "multi_scalar_mul" => multi_scalar_mul(interner, args, location), + "poseidon2_permutation" => poseidon2_permutation(interner, args, location), + "keccakf1600" => keccakf1600(interner, args, location), + "range" => apply_range_constraint(args, location), + "sha256_compression" => sha256_compression(interner, args, location), _ => { - let item = format!("Comptime evaluation for builtin function {name}"); - Err(InterpreterError::Unimplemented { item, location }) + let explanation = match name { + "schnorr_verify" => "Schnorr verification will be removed.".into(), + "and" | "xor" => "It should be turned into a binary operation.".into(), + "recursive_aggregation" => "A proof cannot be verified at comptime.".into(), + _ => { + let item = format!("Comptime evaluation for foreign function '{name}'"); + return Err(InterpreterError::Unimplemented { item, location }); + } + }; + + let item = format!("Attempting to evaluate foreign function '{name}'"); + Err(InterpreterError::InvalidInComptimeContext { item, location, explanation }) } } } -pub(super) fn apply_range_constraint( +/// `pub fn aes128_encrypt(input: [u8; N], iv: [u8; 16], key: [u8; 16]) -> [u8]` +fn aes128_encrypt( + interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { + let (inputs, iv, key) = check_three_arguments(arguments, location)?; + + let (inputs, _) = get_array_map(interner, inputs, get_u8)?; + let (iv, _) = get_fixed_array_map(interner, iv, get_u8)?; + let (key, _) = get_fixed_array_map(interner, key, get_u8)?; + + let output = acvm::blackbox_solver::aes128_encrypt(&inputs, iv, key) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_byte_slice(&output)) +} + +fn apply_range_constraint(arguments: Vec<(Value, Location)>, location: Location) -> IResult { let (value, num_bits) = check_two_arguments(arguments, location)?; let input = get_field(value)?; @@ -53,21 +139,192 @@ pub(super) fn apply_range_constraint( } } -// poseidon2_permutation(_input: [Field; N], _state_length: u32) -> [Field; N] +/// `fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt` +/// +/// Returns the ID of the new bigint allocated by the solver. +fn bigint_from_le_bytes( + interner: &mut NodeInterner, + solver: &mut BigIntSolverWithId, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + let (bytes, modulus) = check_two_arguments(arguments, location)?; + + let (bytes, _) = get_slice_map(interner, bytes, get_u8)?; + let (modulus, _) = get_slice_map(interner, modulus, get_u8)?; + + let id = solver + .bigint_from_bytes(&bytes, &modulus) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_bigint(id, return_type)) +} + +/// `fn to_le_bytes(self) -> [u8; 32]` +/// +/// Take the ID of a bigint and returned its content. +fn bigint_to_le_bytes( + solver: &mut BigIntSolverWithId, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let int = check_one_argument(arguments, location)?; + let id = get_bigint_id(int)?; + + let mut bytes = + solver.bigint_to_bytes(id).map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + assert!(bytes.len() <= 32); + bytes.resize(32, 0); + + Ok(to_byte_array(&bytes)) +} + +/// `fn bigint_add(self, other: BigInt) -> BigInt` +/// +/// Takes two previous allocated IDs, gets the values from the solver, +/// stores the result of the operation, returns the new ID. +fn bigint_op( + solver: &mut BigIntSolverWithId, + func: BlackBoxFunc, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + let (lhs, rhs) = check_two_arguments(arguments, location)?; + + let lhs = get_bigint_id(lhs)?; + let rhs = get_bigint_id(rhs)?; + + let id = solver + .bigint_op(lhs, rhs, func) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_bigint(id, return_type)) +} + +/// Run one of the Blake hash functions. +/// ```text +/// pub fn blake2s(input: [u8; N]) -> [u8; 32] +/// pub fn blake3(input: [u8; N]) -> [u8; 32] +/// ``` +fn blake_hash( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, + f: impl Fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, +) -> IResult { + let inputs = check_one_argument(arguments, location)?; + + let (inputs, _) = get_array_map(interner, inputs, get_u8)?; + let output = f(&inputs).map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_byte_array(&output)) +} + +/// Run one of the Secp256 signature verifications. +/// ```text +/// pub fn verify_signature( +/// public_key_x: [u8; 32], +/// public_key_y: [u8; 32], +/// signature: [u8; 64], +/// message_hash: [u8; N], +/// ) -> bool + +/// pub fn verify_signature_slice( +/// public_key_x: [u8; 32], +/// public_key_y: [u8; 32], +/// signature: [u8; 64], +/// message_hash: [u8], +/// ) -> bool +/// ``` +fn ecdsa_secp256_verify( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, + f: impl Fn(&[u8], &[u8; 32], &[u8; 32], &[u8; 64]) -> Result, +) -> IResult { + let [pub_key_x, pub_key_y, sig, msg_hash] = check_arguments(arguments, location)?; + + let (pub_key_x, _) = get_fixed_array_map(interner, pub_key_x, get_u8)?; + let (pub_key_y, _) = get_fixed_array_map(interner, pub_key_y, get_u8)?; + let (sig, _) = get_fixed_array_map(interner, sig, get_u8)?; + + // Hash can be an array or slice. + let (msg_hash, _) = if matches!(msg_hash.0.get_type().as_ref(), Type::Array(_, _)) { + get_array_map(interner, msg_hash.clone(), get_u8)? + } else { + get_slice_map(interner, msg_hash, get_u8)? + }; + + let is_valid = f(&msg_hash, &pub_key_x, &pub_key_y, &sig) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(Value::Bool(is_valid)) +} + +/// ```text +/// fn embedded_curve_add( +/// point1: EmbeddedCurvePoint, +/// point2: EmbeddedCurvePoint, +/// ) -> [Field; 3] +/// ``` +fn embedded_curve_add(arguments: Vec<(Value, Location)>, location: Location) -> IResult { + let (point1, point2) = check_two_arguments(arguments, location)?; + + let (p1x, p1y, p1inf) = get_embedded_curve_point(point1)?; + let (p2x, p2y, p2inf) = get_embedded_curve_point(point2)?; + + let (x, y, inf) = Bn254BlackBoxSolver + .ec_add(&p1x, &p1y, &p1inf.into(), &p2x, &p2y, &p2inf.into()) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_field_array(&[x, y, inf])) +} + +/// ```text +/// pub fn multi_scalar_mul( +/// points: [EmbeddedCurvePoint; N], +/// scalars: [EmbeddedCurveScalar; N], +/// ) -> [Field; 3] +/// ``` +fn multi_scalar_mul( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let (points, scalars) = check_two_arguments(arguments, location)?; + + let (points, _) = get_array_map(interner, points, get_embedded_curve_point)?; + let (scalars, _) = get_array_map(interner, scalars, get_embedded_curve_scalar)?; + + let points: Vec<_> = points.into_iter().flat_map(|(x, y, inf)| [x, y, inf.into()]).collect(); + let mut scalars_lo = Vec::new(); + let mut scalars_hi = Vec::new(); + for (lo, hi) in scalars { + scalars_lo.push(lo); + scalars_hi.push(hi); + } + + let (x, y, inf) = Bn254BlackBoxSolver + .multi_scalar_mul(&points, &scalars_lo, &scalars_hi) + .map_err(|e| InterpreterError::BlackBoxError(e, location))?; + + Ok(to_field_array(&[x, y, inf])) +} + +/// `poseidon2_permutation(_input: [Field; N], _state_length: u32) -> [Field; N]` fn poseidon2_permutation( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let (input, state_length) = check_two_arguments(arguments, location)?; - let input_location = input.1; - let (input, typ) = get_array(interner, input)?; + let (input, typ) = get_array_map(interner, input, get_field)?; let state_length = get_u32(state_length)?; - let input = try_vecmap(input, |integer| get_field((integer, input_location)))?; - - // Currently locked to only bn254! let fields = Bn254BlackBoxSolver .poseidon2_permutation(&input, state_length) .map_err(|error| InterpreterError::BlackBoxError(error, location))?; @@ -76,25 +333,135 @@ fn poseidon2_permutation( Ok(Value::Array(array, typ)) } +/// `fn keccakf1600(input: [u64; 25]) -> [u64; 25] {}` fn keccakf1600( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let input = check_one_argument(arguments, location)?; - let input_location = input.1; - let (input, typ) = get_array(interner, input)?; + let (state, typ) = get_fixed_array_map(interner, input, get_u64)?; - let input = try_vecmap(input, |integer| get_u64((integer, input_location)))?; - - let mut state = [0u64; 25]; - for (it, input_value) in state.iter_mut().zip(input.iter()) { - *it = *input_value; - } let result_lanes = acvm::blackbox_solver::keccakf1600(state) .map_err(|error| InterpreterError::BlackBoxError(error, location))?; let array: Vector = result_lanes.into_iter().map(Value::U64).collect(); Ok(Value::Array(array, typ)) } + +/// `pub fn sha256_compression(input: [u32; 16], state: [u32; 8]) -> [u32; 8]` +fn sha256_compression( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let (input, state) = check_two_arguments(arguments, location)?; + + let (input, _) = get_fixed_array_map(interner, input, get_u32)?; + let (mut state, typ) = get_fixed_array_map(interner, state, get_u32)?; + + acvm::blackbox_solver::sha256_compression(&mut state, &input); + + let state = state.into_iter().map(Value::U32).collect(); + Ok(Value::Array(state, typ)) +} + +/// Decode a `BigInt` struct. +/// +/// Returns the ID of the value in the solver. +fn get_bigint_id((value, location): (Value, Location)) -> IResult { + let (fields, typ) = get_struct_fields("BigInt", (value, location))?; + let p = get_struct_field("pointer", &fields, &typ, location, get_u32)?; + let m = get_struct_field("modulus", &fields, &typ, location, get_u32)?; + assert_eq!(p, m, "`pointer` and `modulus` are expected to be the same"); + Ok(p) +} + +/// Decode an `EmbeddedCurvePoint` struct. +/// +/// Returns `(x, y, is_infinite)`. +fn get_embedded_curve_point( + (value, location): (Value, Location), +) -> IResult<(FieldElement, FieldElement, bool)> { + let (fields, typ) = get_struct_fields("EmbeddedCurvePoint", (value, location))?; + let x = get_struct_field("x", &fields, &typ, location, get_field)?; + let y = get_struct_field("y", &fields, &typ, location, get_field)?; + let is_infinite = get_struct_field("is_infinite", &fields, &typ, location, get_bool)?; + Ok((x, y, is_infinite)) +} + +/// Decode an `EmbeddedCurveScalar` struct. +/// +/// Returns `(lo, hi)`. +fn get_embedded_curve_scalar( + (value, location): (Value, Location), +) -> IResult<(FieldElement, FieldElement)> { + let (fields, typ) = get_struct_fields("EmbeddedCurveScalar", (value, location))?; + let lo = get_struct_field("lo", &fields, &typ, location, get_field)?; + let hi = get_struct_field("hi", &fields, &typ, location, get_field)?; + Ok((lo, hi)) +} + +fn to_bigint(id: u32, typ: Type) -> Value { + to_struct([("pointer", Value::U32(id)), ("modulus", Value::U32(id))], typ) +} + +#[cfg(test)] +mod tests { + use acvm::acir::BlackBoxFunc; + use noirc_errors::Location; + use strum::IntoEnumIterator; + + use crate::hir::comptime::tests::with_interpreter; + use crate::hir::comptime::InterpreterError::{ + ArgumentCountMismatch, InvalidInComptimeContext, Unimplemented, + }; + use crate::Type; + + use super::call_foreign; + + /// Check that all `BlackBoxFunc` are covered by `call_foreign`. + #[test] + fn test_blackbox_implemented() { + let dummy = " + comptime fn main() -> pub u8 { + 0 + } + "; + + let not_implemented = with_interpreter(dummy, |interpreter, _, _| { + let no_location = Location::dummy(); + let mut not_implemented = Vec::new(); + + for blackbox in BlackBoxFunc::iter() { + let name = blackbox.name(); + match call_foreign( + interpreter.elaborator.interner, + &mut interpreter.bigint_solver, + name, + Vec::new(), + Type::Unit, + no_location, + ) { + Ok(_) => { + // Exists and works with no args (unlikely) + } + Err(ArgumentCountMismatch { .. }) => { + // Exists but doesn't work with no args (expected) + } + Err(InvalidInComptimeContext { .. }) => {} + Err(Unimplemented { .. }) => not_implemented.push(name), + Err(other) => panic!("unexpected error: {other:?}"), + }; + } + + not_implemented + }); + + assert!( + not_implemented.is_empty(), + "unimplemented blackbox functions: {not_implemented:?}" + ); + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index e033ec6ddb9..2d3bf928917 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -9,14 +9,20 @@ use noirc_errors::Location; use super::errors::InterpreterError; use super::value::Value; +use super::Interpreter; use crate::elaborator::Elaborator; -use crate::hir::def_collector::dc_crate::DefCollector; +use crate::hir::def_collector::dc_crate::{CompilationError, DefCollector}; use crate::hir::def_collector::dc_mod::collect_defs; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData}; use crate::hir::{Context, ParsedFiles}; +use crate::node_interner::FuncId; use crate::parse_program; -fn interpret_helper(src: &str) -> Result { +/// Create an interpreter for a code snippet and pass it to a test function. +pub(crate) fn with_interpreter( + src: &str, + f: impl FnOnce(&mut Interpreter, FuncId, &[(CompilationError, FileId)]) -> T, +) -> T { let file = FileId::default(); // Can't use Index::test_new here for some reason, even with #[cfg(test)]. @@ -51,14 +57,24 @@ fn interpret_helper(src: &str) -> Result { context.def_maps.insert(krate, collector.def_map); let main = context.get_main_function(&krate).expect("Expected 'main' function"); + let mut elaborator = Elaborator::elaborate_and_return_self(&mut context, krate, collector.items, None); - assert_eq!(elaborator.errors.len(), 0); + + let errors = elaborator.errors.clone(); let mut interpreter = elaborator.setup_interpreter(); - let no_location = Location::dummy(); - interpreter.call_function(main, Vec::new(), HashMap::new(), no_location) + f(&mut interpreter, main, &errors) +} + +/// Evaluate a code snippet by calling the `main` function. +fn interpret_helper(src: &str) -> Result { + with_interpreter(src, |interpreter, main, errors| { + assert_eq!(errors.len(), 0); + let no_location = Location::dummy(); + interpreter.call_function(main, Vec::new(), HashMap::new(), no_location) + }) } fn interpret(src: &str) -> Value { diff --git a/test_programs/noir_test_success/comptime_blackbox/Nargo.toml b/test_programs/noir_test_success/comptime_blackbox/Nargo.toml new file mode 100644 index 00000000000..5eac6f3c91a --- /dev/null +++ b/test_programs/noir_test_success/comptime_blackbox/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_blackbox" +type = "bin" +authors = [""] +compiler_version = ">=0.27.0" + +[dependencies] diff --git a/test_programs/noir_test_success/comptime_blackbox/src/main.nr b/test_programs/noir_test_success/comptime_blackbox/src/main.nr new file mode 100644 index 00000000000..c3784e73b09 --- /dev/null +++ b/test_programs/noir_test_success/comptime_blackbox/src/main.nr @@ -0,0 +1,155 @@ +//! Tests to show that the comptime interpreter implement blackbox functions. +use std::bigint; +use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; + +/// Test that all bigint operations work in comptime. +#[test] +fn test_bigint() { + let result: [u8] = comptime { + let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); + let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); + let c = (a + b) * b / a - a; + c.to_le_bytes() + }; + // Do the same calculation outside comptime. + let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); + let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); + let c = bigint::Secpk1Fq::from_le_bytes(result); + assert_eq(c, (a + b) * b / a - a); +} + +/// Test that to_le_radix returns an array. +#[test] +fn test_to_le_radix() { + comptime { + let field = 2; + let bytes: [u8; 8] = field.to_le_radix(256); + let _num = bigint::BigInt::from_le_bytes(bytes, bigint::bn254_fq); + }; +} + +#[test] +fn test_bitshift() { + let c = comptime { + let a: i32 = 10; + let b: u32 = 4; + a << b as u8 + }; + assert_eq(c, 160); +} + +#[test] +fn test_aes128_encrypt() { + let ciphertext = comptime { + let plaintext: [u8; 5] = [1, 2, 3, 4, 5]; + let iv: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + let key: [u8; 16] = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]; + std::aes128::aes128_encrypt(plaintext, iv, key) + }; + let clear_len = 5; + let cipher_len = clear_len + 16 - clear_len % 16; + assert_eq(ciphertext.len(), cipher_len); +} + +#[test] +fn test_blake2s() { + let hash = comptime { + let input = [104, 101, 108, 108, 111]; + std::hash::blake2s(input) + }; + assert_eq(hash[0], 0x19); + assert_eq(hash[31], 0x25); +} + +#[test] +fn test_blake3() { + let hash = comptime { + let input = [104, 101, 108, 108, 111]; + std::hash::blake3(input) + }; + assert_eq(hash[0], 0xea); + assert_eq(hash[31], 0x0f); +} + +/// Test that ecdsa_secp256k1 is implemented. +#[test] +fn test_ecdsa_secp256k1() { + let (valid_array, valid_slice) = comptime { + let pub_key_x: [u8; 32] = hex_to_bytes("a0434d9e47f3c86235477c7b1ae6ae5d3442d49b1943c2b752a68e2a47e247c7").as_array(); + let pub_key_y: [u8; 32] = hex_to_bytes("893aba425419bc27a3b6c7e693a24c696f794c2ed877a1593cbee53b037368d7").as_array(); + let signature: [u8; 64] = hex_to_bytes("e5081c80ab427dc370346f4a0e31aa2bad8d9798c38061db9ae55a4e8df454fd28119894344e71b78770cc931d61f480ecbb0b89d6eb69690161e49a715fcd55").as_array(); + let hashed_message: [u8; 32] = hex_to_bytes("3a73f4123a5cd2121f21cd7e8d358835476949d035d9c2da6806b4633ac8c1e2").as_array(); + + let valid_array = std::ecdsa_secp256k1::verify_signature(pub_key_x, pub_key_y, signature, hashed_message); + let valid_slice = std::ecdsa_secp256k1::verify_signature_slice(pub_key_x, pub_key_y, signature, hashed_message.as_slice()); + + (valid_array, valid_slice) + }; + assert(valid_array); + assert(valid_slice); +} + +/// Test that ecdsa_secp256r1 is implemented. +#[test] +fn test_ecdsa_secp256r1() { + let (valid_array, valid_slice) = comptime { + let pub_key_x: [u8; 32] = hex_to_bytes("550f471003f3df97c3df506ac797f6721fb1a1fb7b8f6f83d224498a65c88e24").as_array(); + let pub_key_y: [u8; 32] = hex_to_bytes("136093d7012e509a73715cbd0b00a3cc0ff4b5c01b3ffa196ab1fb327036b8e6").as_array(); + let signature: [u8; 64] = hex_to_bytes("2c70a8d084b62bfc5ce03641caf9f72ad4da8c81bfe6ec9487bb5e1bef62a13218ad9ee29eaf351fdc50f1520c425e9b908a07278b43b0ec7b872778c14e0784").as_array(); + let hashed_message: [u8; 32] = hex_to_bytes("54705ba3baafdbdfba8c5f9a70f7a89bee98d906b53e31074da7baecdc0da9ad").as_array(); + + let valid_array = std::ecdsa_secp256r1::verify_signature(pub_key_x, pub_key_y, signature, hashed_message); + let valid_slice = std::ecdsa_secp256r1::verify_signature_slice(pub_key_x, pub_key_y, signature, hashed_message.as_slice()); + (valid_array, valid_slice) + }; + assert(valid_array); + assert(valid_slice); +} + +/// Test that sha256_compression is implemented. +#[test] +fn test_sha256() { + let hash = comptime { + let input: [u8; 1] = [0xbd]; + std::hash::sha256(input) + }; + assert_eq(hash[0], 0x68); + assert_eq(hash[31], 0x2b); +} + +/// Test that `embedded_curve_add` and `multi_scalar_mul` are implemented. +#[test] +fn test_embedded_curve_ops() { + let (sum, mul) = comptime { + let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let s1 = EmbeddedCurveScalar { lo: 1, hi: 0 }; + let sum = g1 + g1; + let mul = multi_scalar_mul([g1, g1], [s1, s1]); + (sum, mul) + }; + assert_eq(sum, mul); +} + +/// Parse a lowercase hexadecimal string (without 0x prefix) as byte slice. +comptime fn hex_to_bytes(s: str) -> [u8] { + assert(N % 2 == 0); + let mut out = &[]; + let bz = s.as_bytes(); + let mut h: u32 = 0; + for i in 0 .. bz.len() { + let ascii = bz[i]; + let d = if ascii < 58 { + ascii - 48 + } else { + assert(ascii >= 97); // enforce >= 'a' + assert(ascii <= 102); // enforce <= 'f' + ascii - 87 + }; + h = h * 16 + d as u32; + if i % 2 == 1 { + out = out.push_back(h as u8); + h = 0; + } + } + out +} diff --git a/tooling/noirc_abi/Cargo.toml b/tooling/noirc_abi/Cargo.toml index a7baf334bff..22114408e18 100644 --- a/tooling/noirc_abi/Cargo.toml +++ b/tooling/noirc_abi/Cargo.toml @@ -23,8 +23,8 @@ num-bigint = "0.4" num-traits = "0.2" [dev-dependencies] -strum = "0.24" -strum_macros = "0.24" +strum.workspace = true +strum_macros.workspace = true proptest.workspace = true proptest-derive.workspace = true